check.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package ws
  2. import "unicode/utf8"
  3. // State represents state of websocket endpoint.
  4. // It used by some functions to be more strict when checking compatibility with RFC6455.
  5. type State uint8
  6. const (
  7. // StateServerSide means that endpoint (caller) is a server.
  8. StateServerSide State = 0x1 << iota
  9. // StateClientSide means that endpoint (caller) is a client.
  10. StateClientSide
  11. // StateExtended means that extension was negotiated during handshake.
  12. StateExtended
  13. // StateFragmented means that endpoint (caller) has received fragmented
  14. // frame and waits for continuation parts.
  15. StateFragmented
  16. )
  17. // Is checks whether the s has v enabled.
  18. func (s State) Is(v State) bool {
  19. return uint8(s)&uint8(v) != 0
  20. }
  21. // Set enables v state on s.
  22. func (s State) Set(v State) State {
  23. return s | v
  24. }
  25. // Clear disables v state on s.
  26. func (s State) Clear(v State) State {
  27. return s & (^v)
  28. }
  29. // ServerSide reports whether states represents server side.
  30. func (s State) ServerSide() bool { return s.Is(StateServerSide) }
  31. // ClientSide reports whether state represents client side.
  32. func (s State) ClientSide() bool { return s.Is(StateClientSide) }
  33. // Extended reports whether state is extended.
  34. func (s State) Extended() bool { return s.Is(StateExtended) }
  35. // Fragmented reports whether state is fragmented.
  36. func (s State) Fragmented() bool { return s.Is(StateFragmented) }
  37. // ProtocolError describes error during checking/parsing websocket frames or
  38. // headers.
  39. type ProtocolError string
  40. // Error implements error interface.
  41. func (p ProtocolError) Error() string { return string(p) }
  42. // Errors used by the protocol checkers.
  43. var (
  44. ErrProtocolOpCodeReserved = ProtocolError("use of reserved op code")
  45. ErrProtocolControlPayloadOverflow = ProtocolError("control frame payload limit exceeded")
  46. ErrProtocolControlNotFinal = ProtocolError("control frame is not final")
  47. ErrProtocolNonZeroRsv = ProtocolError("non-zero rsv bits with no extension negotiated")
  48. ErrProtocolMaskRequired = ProtocolError("frames from client to server must be masked")
  49. ErrProtocolMaskUnexpected = ProtocolError("frames from server to client must be not masked")
  50. ErrProtocolContinuationExpected = ProtocolError("unexpected non-continuation data frame")
  51. ErrProtocolContinuationUnexpected = ProtocolError("unexpected continuation data frame")
  52. ErrProtocolStatusCodeNotInUse = ProtocolError("status code is not in use")
  53. ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level")
  54. ErrProtocolStatusCodeNoMeaning = ProtocolError("status code has no meaning yet")
  55. ErrProtocolStatusCodeUnknown = ProtocolError("status code is not defined in spec")
  56. ErrProtocolInvalidUTF8 = ProtocolError("invalid utf8 sequence in close reason")
  57. )
  58. // CheckHeader checks h to contain valid header data for given state s.
  59. //
  60. // Note that zero state (0) means that state is clean,
  61. // neither server or client side, nor fragmented, nor extended.
  62. func CheckHeader(h Header, s State) error {
  63. if h.OpCode.IsReserved() {
  64. return ErrProtocolOpCodeReserved
  65. }
  66. if h.OpCode.IsControl() {
  67. if h.Length > MaxControlFramePayloadSize {
  68. return ErrProtocolControlPayloadOverflow
  69. }
  70. if !h.Fin {
  71. return ErrProtocolControlNotFinal
  72. }
  73. }
  74. switch {
  75. // [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for
  76. // non-zero values. If a nonzero value is received and none of the
  77. // negotiated extensions defines the meaning of such a nonzero value, the
  78. // receiving endpoint MUST _Fail the WebSocket Connection_.
  79. case h.Rsv != 0 && !s.Extended():
  80. return ErrProtocolNonZeroRsv
  81. // [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
  82. // In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
  83. // as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
  84. // A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
  85. // status code 1002 (protocol error) as defined in Section 7.4.1.
  86. case s.ServerSide() && !h.Masked:
  87. return ErrProtocolMaskRequired
  88. case s.ClientSide() && h.Masked:
  89. return ErrProtocolMaskUnexpected
  90. // [RFC6455]: See detailed explanation in 5.4 section.
  91. case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
  92. return ErrProtocolContinuationExpected
  93. case !s.Fragmented() && h.OpCode == OpContinuation:
  94. return ErrProtocolContinuationUnexpected
  95. default:
  96. return nil
  97. }
  98. }
  99. // CheckCloseFrameData checks received close information
  100. // to be valid RFC6455 compatible close info.
  101. //
  102. // Note that code.Empty() or code.IsAppLevel() will raise error.
  103. //
  104. // If endpoint sends close frame without status code (with frame.Length = 0),
  105. // application should not check its payload.
  106. func CheckCloseFrameData(code StatusCode, reason string) error {
  107. switch {
  108. case code.IsNotUsed():
  109. return ErrProtocolStatusCodeNotInUse
  110. case code.IsProtocolReserved():
  111. return ErrProtocolStatusCodeApplicationLevel
  112. case code == StatusNoMeaningYet:
  113. return ErrProtocolStatusCodeNoMeaning
  114. case code.IsProtocolSpec() && !code.IsProtocolDefined():
  115. return ErrProtocolStatusCodeUnknown
  116. case !utf8.ValidString(reason):
  117. return ErrProtocolInvalidUTF8
  118. default:
  119. return nil
  120. }
  121. }