reader.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. package wsutil
  2. import (
  3. "errors"
  4. "io"
  5. "io/ioutil"
  6. "github.com/gobwas/ws"
  7. )
// ErrNoFrameAdvance is returned by Reader's Read() method when it is called
// without a preceding NextFrame() call and no fragmented message is in
// progress.
var ErrNoFrameAdvance = errors.New("no frame advance")
// FrameHandlerFunc handles a parsed frame header and the frame body
// represented by an io.Reader.
//
// Note that the reader yields the already unmasked body.
type FrameHandlerFunc func(ws.Header, io.Reader) error
// Reader is a wrapper around a source io.Reader which represents a WebSocket
// connection. It contains options for reading messages from the source.
//
// Reader implements io.Reader, whose Read() method reads the payload of
// incoming WebSocket frames. It also takes care of fragmented frames and
// possibly intermediate control frames between them.
//
// Note that Reader's methods are not goroutine safe.
type Reader struct {
	// Source is the underlying stream that frame headers and payloads are
	// read from.
	Source io.Reader

	// State is passed to the header check in NextFrame(). NextFrame() also
	// sets and clears its ws.StateFragmented flag as frames arrive.
	State ws.State

	// SkipHeaderCheck disables checking header bits to be RFC6455 compliant.
	SkipHeaderCheck bool

	// CheckUTF8 enables UTF-8 checks for text frames payload. If incoming
	// bytes are not a valid UTF-8 sequence, ErrInvalidUTF8 is returned.
	CheckUTF8 bool

	// Extensions is a list of negotiated extensions for reader Source.
	// It is used to meet the specs and clear appropriate bits in fragment
	// header RSV segment.
	Extensions []RecvExtension

	// TODO(gobwas): add max frame size limit here.

	// OnContinuation, if non-nil, is called by NextFrame() for every frame
	// whose opcode is ws.OpContinuation.
	OnContinuation FrameHandlerFunc

	// OnIntermediate, if non-nil, is called by NextFrame() for every control
	// frame received while a fragmented message is in progress.
	OnIntermediate FrameHandlerFunc

	opCode ws.OpCode        // Used to store message op code on fragmentation.
	frame  io.Reader        // Used as the current frame reader.
	raw    io.LimitedReader // Used to discard frames without cipher.
	utf8   UTF8Reader       // Used to check UTF8 sequences if CheckUTF8 is true.
	fseq   int              // Fragment sequence in message counter.
}
  45. // NewReader creates new frame reader that reads from r keeping given state to
  46. // make some protocol validity checks when it needed.
  47. func NewReader(r io.Reader, s ws.State) *Reader {
  48. return &Reader{
  49. Source: r,
  50. State: s,
  51. }
  52. }
  53. // NewClientSideReader is a helper function that calls NewReader with r and
  54. // ws.StateClientSide.
  55. func NewClientSideReader(r io.Reader) *Reader {
  56. return NewReader(r, ws.StateClientSide)
  57. }
  58. // NewServerSideReader is a helper function that calls NewReader with r and
  59. // ws.StateServerSide.
  60. func NewServerSideReader(r io.Reader) *Reader {
  61. return NewReader(r, ws.StateServerSide)
  62. }
  63. // Read implements io.Reader. It reads the next message payload into p.
  64. // It takes care on fragmented messages.
  65. //
  66. // The error is io.EOF only if all of message bytes were read.
  67. // If an io.EOF happens during reading some but not all the message bytes
  68. // Read() returns io.ErrUnexpectedEOF.
  69. //
  70. // The error is ErrNoFrameAdvance if no NextFrame() call was made before
  71. // reading next message bytes.
  72. func (r *Reader) Read(p []byte) (n int, err error) {
  73. if r.frame == nil {
  74. if !r.fragmented() {
  75. // Every new Read() must be preceded by NextFrame() call.
  76. return 0, ErrNoFrameAdvance
  77. }
  78. // Read next continuation or intermediate control frame.
  79. _, err := r.NextFrame()
  80. if err != nil {
  81. return 0, err
  82. }
  83. if r.frame == nil {
  84. // We handled intermediate control and now got nothing to read.
  85. return 0, nil
  86. }
  87. }
  88. n, err = r.frame.Read(p)
  89. if err != nil && err != io.EOF {
  90. return
  91. }
  92. if err == nil && r.raw.N != 0 {
  93. return n, nil
  94. }
  95. // EOF condition (either err is io.EOF or r.raw.N is zero).
  96. switch {
  97. case r.raw.N != 0:
  98. err = io.ErrUnexpectedEOF
  99. case r.fragmented():
  100. err = nil
  101. r.resetFragment()
  102. case r.CheckUTF8 && !r.utf8.Valid():
  103. // NOTE: check utf8 only when full message received, since partial
  104. // reads may be invalid.
  105. n = r.utf8.Accepted()
  106. err = ErrInvalidUTF8
  107. default:
  108. r.reset()
  109. err = io.EOF
  110. }
  111. return
  112. }
  113. // Discard discards current message unread bytes.
  114. // It discards all frames of fragmented message.
  115. func (r *Reader) Discard() (err error) {
  116. for {
  117. _, err = io.Copy(ioutil.Discard, &r.raw)
  118. if err != nil {
  119. break
  120. }
  121. if !r.fragmented() {
  122. break
  123. }
  124. if _, err = r.NextFrame(); err != nil {
  125. break
  126. }
  127. }
  128. r.reset()
  129. return err
  130. }
  131. // NextFrame prepares r to read next message. It returns received frame header
  132. // and non-nil error on failure.
  133. //
  134. // Note that next NextFrame() call must be done after receiving or discarding
  135. // all current message bytes.
  136. func (r *Reader) NextFrame() (hdr ws.Header, err error) {
  137. hdr, err = ws.ReadHeader(r.Source)
  138. if err == io.EOF && r.fragmented() {
  139. // If we are in fragmented state EOF means that is was totally
  140. // unexpected.
  141. //
  142. // NOTE: This is necessary to prevent callers such that
  143. // ioutil.ReadAll to receive some amount of bytes without an error.
  144. // ReadAll() ignores an io.EOF error, thus caller may think that
  145. // whole message fetched, but actually only part of it.
  146. err = io.ErrUnexpectedEOF
  147. }
  148. if err == nil && !r.SkipHeaderCheck {
  149. err = ws.CheckHeader(hdr, r.State)
  150. }
  151. if err != nil {
  152. return hdr, err
  153. }
  154. // Save raw reader to use it on discarding frame without ciphering and
  155. // other streaming checks.
  156. r.raw = io.LimitedReader{
  157. R: r.Source,
  158. N: hdr.Length,
  159. }
  160. frame := io.Reader(&r.raw)
  161. if hdr.Masked {
  162. frame = NewCipherReader(frame, hdr.Mask)
  163. }
  164. for _, ext := range r.Extensions {
  165. hdr.Rsv, err = ext.BitsRecv(r.fseq, hdr.Rsv)
  166. if err != nil {
  167. return hdr, err
  168. }
  169. }
  170. if r.fragmented() {
  171. if hdr.OpCode.IsControl() {
  172. if cb := r.OnIntermediate; cb != nil {
  173. err = cb(hdr, frame)
  174. }
  175. if err == nil {
  176. // Ensure that src is empty.
  177. _, err = io.Copy(ioutil.Discard, &r.raw)
  178. }
  179. return
  180. }
  181. } else {
  182. r.opCode = hdr.OpCode
  183. }
  184. if r.CheckUTF8 && (hdr.OpCode == ws.OpText || (r.fragmented() && r.opCode == ws.OpText)) {
  185. r.utf8.Source = frame
  186. frame = &r.utf8
  187. }
  188. // Save reader with ciphering and other streaming checks.
  189. r.frame = frame
  190. if hdr.OpCode == ws.OpContinuation {
  191. if cb := r.OnContinuation; cb != nil {
  192. err = cb(hdr, frame)
  193. }
  194. }
  195. if hdr.Fin {
  196. r.State = r.State.Clear(ws.StateFragmented)
  197. r.fseq = 0
  198. } else {
  199. r.State = r.State.Set(ws.StateFragmented)
  200. r.fseq++
  201. }
  202. return
  203. }
  204. func (r *Reader) fragmented() bool {
  205. return r.State.Fragmented()
  206. }
  207. func (r *Reader) resetFragment() {
  208. r.raw = io.LimitedReader{}
  209. r.frame = nil
  210. // Reset source of the UTF8Reader, but not the state.
  211. r.utf8.Source = nil
  212. }
  213. func (r *Reader) reset() {
  214. r.raw = io.LimitedReader{}
  215. r.frame = nil
  216. r.utf8 = UTF8Reader{}
  217. r.fseq = 0
  218. r.opCode = 0
  219. }
  220. // NextReader prepares next message read from r. It returns header that
  221. // describes the message and io.Reader to read message's payload. It returns
  222. // non-nil error when it is not possible to read message's initial frame.
  223. //
  224. // Note that next NextReader() on the same r should be done after reading all
  225. // bytes from previously returned io.Reader. For more performant way to discard
  226. // message use Reader and its Discard() method.
  227. //
  228. // Note that it will not handle any "intermediate" frames, that possibly could
  229. // be received between text/binary continuation frames. That is, if peer sent
  230. // text/binary frame with fin flag "false", then it could send ping frame, and
  231. // eventually remaining part of text/binary frame with fin "true" – with
  232. // NextReader() the ping frame will be dropped without any notice. To handle
  233. // this rare, but possible situation (and if you do not know exactly which
  234. // frames peer could send), you could use Reader with OnIntermediate field set.
  235. func NextReader(r io.Reader, s ws.State) (ws.Header, io.Reader, error) {
  236. rd := &Reader{
  237. Source: r,
  238. State: s,
  239. }
  240. header, err := rd.NextFrame()
  241. if err != nil {
  242. return header, nil, err
  243. }
  244. return header, rd, nil
  245. }