123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504 |
- package ws
- import (
- "bufio"
- "bytes"
- "io"
- "net/http"
- "net/textproto"
- "net/url"
- "strconv"
- "github.com/gobwas/httphead"
- )
// Byte sequences used when writing HTTP messages by hand.
const (
	crlf          = "\r\n" // HTTP line terminator
	colonAndSpace = ": "   // separates a header key from its value
	commaAndSpace = ", "   // separates the items of a list-valued header

	// textHeadUpgrade is the fixed beginning of every successful upgrade
	// response: the 101 status line followed by the Upgrade and Connection
	// headers. httpWriteResponseUpgrade appends the remaining headers.
	textHeadUpgrade = "HTTP/1.1 101 Switching Protocols" + crlf +
		"Upgrade: websocket" + crlf +
		"Connection: Upgrade" + crlf
)
- var (
- textHeadBadRequest = statusText(http.StatusBadRequest)
- textHeadInternalServerError = statusText(http.StatusInternalServerError)
- textHeadUpgradeRequired = statusText(http.StatusUpgradeRequired)
- textTailErrHandshakeBadProtocol = errorText(ErrHandshakeBadProtocol)
- textTailErrHandshakeBadMethod = errorText(ErrHandshakeBadMethod)
- textTailErrHandshakeBadHost = errorText(ErrHandshakeBadHost)
- textTailErrHandshakeBadUpgrade = errorText(ErrHandshakeBadUpgrade)
- textTailErrHandshakeBadConnection = errorText(ErrHandshakeBadConnection)
- textTailErrHandshakeBadSecAccept = errorText(ErrHandshakeBadSecAccept)
- textTailErrHandshakeBadSecKey = errorText(ErrHandshakeBadSecKey)
- textTailErrHandshakeBadSecVersion = errorText(ErrHandshakeBadSecVersion)
- textTailErrUpgradeRequired = errorText(ErrHandshakeUpgradeRequired)
- )
// Header field names used in the WebSocket opening handshake, spelled as
// they are written on the wire.
var (
	headerHost          = "Host"
	headerUpgrade       = "Upgrade"
	headerConnection    = "Connection"
	headerSecVersion    = "Sec-WebSocket-Version"
	headerSecProtocol   = "Sec-WebSocket-Protocol"
	headerSecExtensions = "Sec-WebSocket-Extensions"
	headerSecKey        = "Sec-WebSocket-Key"
	headerSecAccept     = "Sec-WebSocket-Accept"
)

// The same names in the canonical form produced by
// textproto.CanonicalMIMEHeaderKey (e.g. "Sec-Websocket-Key"). These are the
// keys under which net/http stores the fields in an http.Header map.
var (
	headerHostCanonical          = textproto.CanonicalMIMEHeaderKey(headerHost)
	headerUpgradeCanonical       = textproto.CanonicalMIMEHeaderKey(headerUpgrade)
	headerConnectionCanonical    = textproto.CanonicalMIMEHeaderKey(headerConnection)
	headerSecVersionCanonical    = textproto.CanonicalMIMEHeaderKey(headerSecVersion)
	headerSecProtocolCanonical   = textproto.CanonicalMIMEHeaderKey(headerSecProtocol)
	headerSecExtensionsCanonical = textproto.CanonicalMIMEHeaderKey(headerSecExtensions)
	headerSecKeyCanonical        = textproto.CanonicalMIMEHeaderKey(headerSecKey)
	headerSecAcceptCanonical     = textproto.CanonicalMIMEHeaderKey(headerSecAccept)
)
// Header values mandated for the opening handshake. specHeaderValueConnection
// and its lowercase twin are the same token in two spellings.
var (
	specHeaderValueUpgrade         = []byte("websocket") // Upgrade header value
	specHeaderValueConnection      = []byte("Upgrade")   // Connection header value
	specHeaderValueConnectionLower = []byte("upgrade")   // lowercase spelling of the above
	specHeaderValueSecVersion      = []byte("13")        // Sec-WebSocket-Version value
)

// HTTP-version tokens compared against by httpParseVersion: the two common
// versions are matched verbatim, anything else must start with the prefix.
var (
	httpVersion1_0    = []byte("HTTP/1.0")
	httpVersion1_1    = []byte("HTTP/1.1")
	httpVersionPrefix = []byte("HTTP/")
)
// httpRequestLine holds the parts of a request line such as
// "GET /chat HTTP/1.1", as filled in by httpParseRequestLine.
// NOTE(review): method and uri come from bsplit3 and presumably alias the
// parsed line rather than being copies — confirm before retaining them.
type httpRequestLine struct {
	method []byte
	uri    []byte
	major  int
	minor  int
}

// httpResponseLine holds the parts of a status line such as
// "HTTP/1.1 101 Switching Protocols", as filled in by httpParseResponseLine.
type httpResponseLine struct {
	major  int
	minor  int
	status int
	reason []byte
}
- // httpParseRequestLine parses http request line like "GET / HTTP/1.0".
- func httpParseRequestLine(line []byte) (req httpRequestLine, err error) {
- var proto []byte
- req.method, req.uri, proto = bsplit3(line, ' ')
- var ok bool
- req.major, req.minor, ok = httpParseVersion(proto)
- if !ok {
- err = ErrMalformedRequest
- return
- }
- return
- }
- func httpParseResponseLine(line []byte) (resp httpResponseLine, err error) {
- var (
- proto []byte
- status []byte
- )
- proto, status, resp.reason = bsplit3(line, ' ')
- var ok bool
- resp.major, resp.minor, ok = httpParseVersion(proto)
- if !ok {
- return resp, ErrMalformedResponse
- }
- var convErr error
- resp.status, convErr = asciiToInt(status)
- if convErr != nil {
- return resp, ErrMalformedResponse
- }
- return resp, nil
- }
- // httpParseVersion parses major and minor version of HTTP protocol. It returns
- // parsed values and true if parse is ok.
- func httpParseVersion(bts []byte) (major, minor int, ok bool) {
- switch {
- case bytes.Equal(bts, httpVersion1_0):
- return 1, 0, true
- case bytes.Equal(bts, httpVersion1_1):
- return 1, 1, true
- case len(bts) < 8:
- return
- case !bytes.Equal(bts[:5], httpVersionPrefix):
- return
- }
- bts = bts[5:]
- dot := bytes.IndexByte(bts, '.')
- if dot == -1 {
- return
- }
- var err error
- major, err = asciiToInt(bts[:dot])
- if err != nil {
- return
- }
- minor, err = asciiToInt(bts[dot+1:])
- if err != nil {
- return
- }
- return major, minor, true
- }
- // httpParseHeaderLine parses HTTP header as key-value pair. It returns parsed
- // values and true if parse is ok.
- func httpParseHeaderLine(line []byte) (k, v []byte, ok bool) {
- colon := bytes.IndexByte(line, ':')
- if colon == -1 {
- return
- }
- k = btrim(line[:colon])
- // TODO(gobwas): maybe use just lower here?
- canonicalizeHeaderKey(k)
- v = btrim(line[colon+1:])
- return k, v, true
- }
// httpGetHeader returns the first value stored under key in h, or "" when
// there is none (including when h is nil). It is textproto.MIMEHeader.Get
// without key canonicalization: key must already be canonical, which saves
// the canonicalization work on every lookup.
func httpGetHeader(h http.Header, key string) string {
	// Indexing a nil map is safe and yields a nil slice.
	if values := h[key]; len(values) > 0 {
		return values[0]
	}
	return ""
}
- // The request MAY include a header field with the name
- // |Sec-WebSocket-Protocol|. If present, this value indicates one or more
- // comma-separated subprotocol the client wishes to speak, ordered by
- // preference. The elements that comprise this value MUST be non-empty strings
- // with characters in the range U+0021 to U+007E not including separator
- // characters as defined in [RFC2616] and MUST all be unique strings. The ABNF
- // for the value of this header field is 1#token, where the definitions of
- // constructs and rules are as given in [RFC2616].
- func strSelectProtocol(h string, check func(string) bool) (ret string, ok bool) {
- ok = httphead.ScanTokens(strToBytes(h), func(v []byte) bool {
- if check(btsToString(v)) {
- ret = string(v)
- return false
- }
- return true
- })
- return
- }
- func btsSelectProtocol(h []byte, check func([]byte) bool) (ret string, ok bool) {
- var selected []byte
- ok = httphead.ScanTokens(h, func(v []byte) bool {
- if check(v) {
- selected = v
- return false
- }
- return true
- })
- if ok && selected != nil {
- return string(selected), true
- }
- return
- }
- func btsSelectExtensions(h []byte, selected []httphead.Option, check func(httphead.Option) bool) ([]httphead.Option, bool) {
- s := httphead.OptionSelector{
- Flags: httphead.SelectCopy,
- Check: check,
- }
- return s.Select(h, selected)
- }
- func negotiateMaybe(in httphead.Option, dest []httphead.Option, f func(httphead.Option) (httphead.Option, error)) ([]httphead.Option, error) {
- if in.Size() == 0 {
- return dest, nil
- }
- opt, err := f(in)
- if err != nil {
- return nil, err
- }
- if opt.Size() > 0 {
- dest = append(dest, opt)
- }
- return dest, nil
- }
- func negotiateExtensions(
- h []byte, dest []httphead.Option,
- f func(httphead.Option) (httphead.Option, error),
- ) (_ []httphead.Option, err error) {
- index := -1
- var current httphead.Option
- ok := httphead.ScanOptions(h, func(i int, name, attr, val []byte) httphead.Control {
- if i != index {
- dest, err = negotiateMaybe(current, dest, f)
- if err != nil {
- return httphead.ControlBreak
- }
- index = i
- current = httphead.Option{Name: name}
- }
- if attr != nil {
- current.Parameters.Set(attr, val)
- }
- return httphead.ControlContinue
- })
- if !ok {
- return nil, ErrMalformedRequest
- }
- return negotiateMaybe(current, dest, f)
- }
- func httpWriteHeader(bw *bufio.Writer, key, value string) {
- httpWriteHeaderKey(bw, key)
- bw.WriteString(value)
- bw.WriteString(crlf)
- }
- func httpWriteHeaderBts(bw *bufio.Writer, key string, value []byte) {
- httpWriteHeaderKey(bw, key)
- bw.Write(value)
- bw.WriteString(crlf)
- }
- func httpWriteHeaderKey(bw *bufio.Writer, key string) {
- bw.WriteString(key)
- bw.WriteString(colonAndSpace)
- }
- func httpWriteUpgradeRequest(
- bw *bufio.Writer,
- u *url.URL,
- nonce []byte,
- protocols []string,
- extensions []httphead.Option,
- header HandshakeHeader,
- ) {
- bw.WriteString("GET ")
- bw.WriteString(u.RequestURI())
- bw.WriteString(" HTTP/1.1\r\n")
- httpWriteHeader(bw, headerHost, u.Host)
- httpWriteHeaderBts(bw, headerUpgrade, specHeaderValueUpgrade)
- httpWriteHeaderBts(bw, headerConnection, specHeaderValueConnection)
- httpWriteHeaderBts(bw, headerSecVersion, specHeaderValueSecVersion)
- // NOTE: write nonce bytes as a string to prevent heap allocation –
- // WriteString() copy given string into its inner buffer, unlike Write()
- // which may write p directly to the underlying io.Writer – which in turn
- // will lead to p escape.
- httpWriteHeader(bw, headerSecKey, btsToString(nonce))
- if len(protocols) > 0 {
- httpWriteHeaderKey(bw, headerSecProtocol)
- for i, p := range protocols {
- if i > 0 {
- bw.WriteString(commaAndSpace)
- }
- bw.WriteString(p)
- }
- bw.WriteString(crlf)
- }
- if len(extensions) > 0 {
- httpWriteHeaderKey(bw, headerSecExtensions)
- httphead.WriteOptions(bw, extensions)
- bw.WriteString(crlf)
- }
- if header != nil {
- header.WriteTo(bw)
- }
- bw.WriteString(crlf)
- }
- func httpWriteResponseUpgrade(bw *bufio.Writer, nonce []byte, hs Handshake, header HandshakeHeaderFunc) {
- bw.WriteString(textHeadUpgrade)
- httpWriteHeaderKey(bw, headerSecAccept)
- writeAccept(bw, nonce)
- bw.WriteString(crlf)
- if hs.Protocol != "" {
- httpWriteHeader(bw, headerSecProtocol, hs.Protocol)
- }
- if len(hs.Extensions) > 0 {
- httpWriteHeaderKey(bw, headerSecExtensions)
- httphead.WriteOptions(bw, hs.Extensions)
- bw.WriteString(crlf)
- }
- if header != nil {
- header(bw)
- }
- bw.WriteString(crlf)
- }
- func httpWriteResponseError(bw *bufio.Writer, err error, code int, header HandshakeHeaderFunc) {
- switch code {
- case http.StatusBadRequest:
- bw.WriteString(textHeadBadRequest)
- case http.StatusInternalServerError:
- bw.WriteString(textHeadInternalServerError)
- case http.StatusUpgradeRequired:
- bw.WriteString(textHeadUpgradeRequired)
- default:
- writeStatusText(bw, code)
- }
- // Write custom headers.
- if header != nil {
- header(bw)
- }
- switch err {
- case ErrHandshakeBadProtocol:
- bw.WriteString(textTailErrHandshakeBadProtocol)
- case ErrHandshakeBadMethod:
- bw.WriteString(textTailErrHandshakeBadMethod)
- case ErrHandshakeBadHost:
- bw.WriteString(textTailErrHandshakeBadHost)
- case ErrHandshakeBadUpgrade:
- bw.WriteString(textTailErrHandshakeBadUpgrade)
- case ErrHandshakeBadConnection:
- bw.WriteString(textTailErrHandshakeBadConnection)
- case ErrHandshakeBadSecAccept:
- bw.WriteString(textTailErrHandshakeBadSecAccept)
- case ErrHandshakeBadSecKey:
- bw.WriteString(textTailErrHandshakeBadSecKey)
- case ErrHandshakeBadSecVersion:
- bw.WriteString(textTailErrHandshakeBadSecVersion)
- case ErrHandshakeUpgradeRequired:
- bw.WriteString(textTailErrUpgradeRequired)
- case nil:
- bw.WriteString(crlf)
- default:
- writeErrorText(bw, err)
- }
- }
- func writeStatusText(bw *bufio.Writer, code int) {
- bw.WriteString("HTTP/1.1 ")
- bw.WriteString(strconv.Itoa(code))
- bw.WriteByte(' ')
- bw.WriteString(http.StatusText(code))
- bw.WriteString(crlf)
- bw.WriteString("Content-Type: text/plain; charset=utf-8")
- bw.WriteString(crlf)
- }
- func writeErrorText(bw *bufio.Writer, err error) {
- body := err.Error()
- bw.WriteString("Content-Length: ")
- bw.WriteString(strconv.Itoa(len(body)))
- bw.WriteString(crlf)
- bw.WriteString(crlf)
- bw.WriteString(body)
- }
// httpError replies to a request through a regular http.ResponseWriter with
// a plain-text error. It is a variant of http.Error that also sets an
// explicit Content-Length header and writes body verbatim, with no trailing
// newline.
func httpError(w http.ResponseWriter, body string, code int) {
	hdr := w.Header()
	hdr.Set("Content-Type", "text/plain; charset=utf-8")
	hdr.Set("Content-Length", strconv.Itoa(len(body)))
	w.WriteHeader(code)
	w.Write([]byte(body))
}
- // statusText is a non-performant status text generator.
- // NOTE: Used only to generate constants.
- func statusText(code int) string {
- var buf bytes.Buffer
- bw := bufio.NewWriter(&buf)
- writeStatusText(bw, code)
- bw.Flush()
- return buf.String()
- }
- // errorText is a non-performant error text generator.
- // NOTE: Used only to generate constants.
- func errorText(err error) string {
- var buf bytes.Buffer
- bw := bufio.NewWriter(&buf)
- writeErrorText(bw, err)
- bw.Flush()
- return buf.String()
- }
// HandshakeHeader writes additional header lines of an upgrade request or
// response into a given io.Writer. The handshake writers append only the
// final blank line after it, so implementations should presumably emit
// complete "Key: value\r\n" lines.
type HandshakeHeader interface {
	WriteTo(w io.Writer) (n int64, err error)
}

// Compile-time checks that every adapter below satisfies HandshakeHeader.
var (
	_ HandshakeHeader = HandshakeHeaderString("")
	_ HandshakeHeader = HandshakeHeaderBytes(nil)
	_ HandshakeHeader = HandshakeHeaderFunc(nil)
	_ HandshakeHeader = HandshakeHeaderHTTP(nil)
)

// HandshakeHeaderString adapts a string of pre-formatted header lines to
// HandshakeHeader.
type HandshakeHeaderString string

// WriteTo implements HandshakeHeader (and io.WriterTo) by writing s as is.
func (s HandshakeHeaderString) WriteTo(w io.Writer) (int64, error) {
	written, err := io.WriteString(w, string(s))
	return int64(written), err
}

// HandshakeHeaderBytes adapts a byte slice of pre-formatted header lines to
// HandshakeHeader.
type HandshakeHeaderBytes []byte

// WriteTo implements HandshakeHeader (and io.WriterTo) by writing b as is.
func (b HandshakeHeaderBytes) WriteTo(w io.Writer) (int64, error) {
	written, err := w.Write(b)
	return int64(written), err
}

// HandshakeHeaderFunc adapts an ordinary function to HandshakeHeader.
type HandshakeHeaderFunc func(io.Writer) (int64, error)

// WriteTo implements HandshakeHeader (and io.WriterTo) by calling f(w).
func (f HandshakeHeaderFunc) WriteTo(w io.Writer) (int64, error) {
	return f(w)
}

// HandshakeHeaderHTTP adapts an http.Header to HandshakeHeader.
type HandshakeHeaderHTTP http.Header

// WriteTo implements HandshakeHeader (and io.WriterTo). It serializes h using
// http.Header.Write and reports how many bytes reached w.
func (h HandshakeHeaderHTTP) WriteTo(w io.Writer) (int64, error) {
	counter := &writer{w: w}
	err := http.Header(h).Write(counter)
	return counter.n, err
}

// writer forwards writes to w and keeps a running total of the bytes
// written in n. It is what lets http.Header.Write, which reports only an
// error, back an io.WriterTo.
type writer struct {
	n int64     // bytes written so far
	w io.Writer // destination
}

// WriteString writes s to the destination via io.WriteString and counts it.
func (w *writer) WriteString(s string) (int, error) {
	return w.tally(io.WriteString(w.w, s))
}

// Write writes p to the destination and counts it.
func (w *writer) Write(p []byte) (int, error) {
	return w.tally(w.w.Write(p))
}

// tally adds n to the running total and passes the write result through.
func (w *writer) tally(n int, err error) (int, error) {
	w.n += int64(n)
	return n, err
}
|