datachannel.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. // Package datachannel implements WebRTC Data Channels
  2. package datachannel
  3. import (
  4. "errors"
  5. "fmt"
  6. "io"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "github.com/pion/logging"
  11. "github.com/pion/sctp"
  12. )
  // receiveMTU is the size in bytes (8 KiB) of the scratch buffer that Server
// uses to read the peer's DATA_CHANNEL_OPEN message.
const receiveMTU = 8 << 10
  // The interfaces below extend the standard io interfaces with the one extra
// piece of per-message information a data channel carries: whether the
// payload is text (a WebRTC string) or binary.
type (
	// Reader is an extended io.Reader
	// that also returns if the message is text.
	Reader interface {
		// ReadDataChannel reads a message into p and reports whether it
		// was sent as text.
		ReadDataChannel(p []byte) (n int, isString bool, err error)
	}

	// ReadDeadliner extends an io.Reader to expose setting a read deadline.
	ReadDeadliner interface {
		SetReadDeadline(deadline time.Time) error
	}

	// Writer is an extended io.Writer
	// that also allows indicating if a message is text.
	Writer interface {
		// WriteDataChannel writes p as one message, flagged as text when
		// isString is true.
		WriteDataChannel(p []byte, isString bool) (n int, err error)
	}

	// ReadWriteCloser is an extended io.ReadWriteCloser
	// that also implements our Reader and Writer.
	ReadWriteCloser interface {
		io.ReadWriteCloser
		Reader
		Writer
	}
)
  // DataChannel represents a data channel
//
// It wraps a single SCTP stream together with the channel's Config, simple
// message/byte counters and an optional OnOpen handler.
type DataChannel struct {
	// Config holds the channel parameters; it is a copy taken by
	// newDataChannel.
	Config

	// stats
	// Updated with sync/atomic in ReadDataChannel / WriteDataChannel and
	// read through MessagesSent, MessagesReceived, BytesSent, BytesReceived.
	// NOTE(review): bytesSent/bytesReceived are accessed with
	// atomic.AddUint64/LoadUint64, which on 32-bit platforms require 8-byte
	// alignment — take care before reordering fields here or in Config.
	messagesSent     uint32
	messagesReceived uint32
	bytesSent        uint64
	bytesReceived    uint64

	// mu is held by OnOpen while it replaces onOpenCompleteHandler and
	// re-arms openCompleteHandlerOnce.
	mu                      sync.Mutex
	onOpenCompleteHandler   func()
	openCompleteHandlerOnce sync.Once

	// stream is the underlying SCTP stream all reads/writes go through.
	stream *sctp.Stream
	log    logging.LeveledLogger
}
  // Config is used to configure the data channel.
//
// On the dialing side these values are sent to the peer in DATA_CHANNEL_OPEN
// (see Client); on the accepting side Server fills the channel's copy from the
// peer's DATA_CHANNEL_OPEN.
type Config struct {
	// ChannelType selects ordering and reliability; commitReliabilityParams
	// maps it onto the SCTP stream's reliability settings.
	ChannelType ChannelType
	// Negotiated, when true, makes Client skip sending DATA_CHANNEL_OPEN
	// (presumably because both sides agreed on the channel out of band).
	Negotiated bool
	// Priority is carried in the DATA_CHANNEL_OPEN message.
	Priority uint16
	// ReliabilityParameter is passed unchanged to the stream's
	// SetReliabilityParams; its meaning depends on ChannelType
	// (NOTE(review): presumably a retransmit count or a lifetime — confirm
	// against pion/sctp).
	ReliabilityParameter uint32
	// Label and Protocol are carried in the DATA_CHANNEL_OPEN message.
	Label    string
	Protocol string
	// LoggerFactory creates the channel's "datachannel" logger. It must be
	// non-nil: newDataChannel calls NewLogger on it unconditionally.
	LoggerFactory logging.LoggerFactory
}
  61. func newDataChannel(stream *sctp.Stream, config *Config) (*DataChannel, error) {
  62. return &DataChannel{
  63. Config: *config,
  64. stream: stream,
  65. log: config.LoggerFactory.NewLogger("datachannel"),
  66. }, nil
  67. }
  68. // Dial opens a data channels over SCTP
  69. func Dial(a *sctp.Association, id uint16, config *Config) (*DataChannel, error) {
  70. stream, err := a.OpenStream(id, sctp.PayloadTypeWebRTCBinary)
  71. if err != nil {
  72. return nil, err
  73. }
  74. dc, err := Client(stream, config)
  75. if err != nil {
  76. return nil, err
  77. }
  78. return dc, nil
  79. }
  80. // Client opens a data channel over an SCTP stream
  81. func Client(stream *sctp.Stream, config *Config) (*DataChannel, error) {
  82. msg := &channelOpen{
  83. ChannelType: config.ChannelType,
  84. Priority: config.Priority,
  85. ReliabilityParameter: config.ReliabilityParameter,
  86. Label: []byte(config.Label),
  87. Protocol: []byte(config.Protocol),
  88. }
  89. if !config.Negotiated {
  90. rawMsg, err := msg.Marshal()
  91. if err != nil {
  92. return nil, fmt.Errorf("failed to marshal ChannelOpen %w", err)
  93. }
  94. if _, err = stream.WriteSCTP(rawMsg, sctp.PayloadTypeWebRTCDCEP); err != nil {
  95. return nil, fmt.Errorf("failed to send ChannelOpen %w", err)
  96. }
  97. }
  98. return newDataChannel(stream, config)
  99. }
  100. // Accept is used to accept incoming data channels over SCTP
  101. func Accept(a *sctp.Association, config *Config, existingChannels ...*DataChannel) (*DataChannel, error) {
  102. stream, err := a.AcceptStream()
  103. if err != nil {
  104. return nil, err
  105. }
  106. for _, ch := range existingChannels {
  107. if ch.StreamIdentifier() == stream.StreamIdentifier() {
  108. ch.stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
  109. return ch, nil
  110. }
  111. }
  112. stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
  113. dc, err := Server(stream, config)
  114. if err != nil {
  115. return nil, err
  116. }
  117. return dc, nil
  118. }
  119. // Server accepts a data channel over an SCTP stream
  120. func Server(stream *sctp.Stream, config *Config) (*DataChannel, error) {
  121. buffer := make([]byte, receiveMTU)
  122. n, ppi, err := stream.ReadSCTP(buffer)
  123. if err != nil {
  124. return nil, err
  125. }
  126. if ppi != sctp.PayloadTypeWebRTCDCEP {
  127. return nil, fmt.Errorf("%w %s", ErrInvalidPayloadProtocolIdentifier, ppi)
  128. }
  129. openMsg, err := parseExpectDataChannelOpen(buffer[:n])
  130. if err != nil {
  131. return nil, fmt.Errorf("failed to parse DataChannelOpen packet %w", err)
  132. }
  133. config.ChannelType = openMsg.ChannelType
  134. config.Priority = openMsg.Priority
  135. config.ReliabilityParameter = openMsg.ReliabilityParameter
  136. config.Label = string(openMsg.Label)
  137. config.Protocol = string(openMsg.Protocol)
  138. dataChannel, err := newDataChannel(stream, config)
  139. if err != nil {
  140. return nil, err
  141. }
  142. err = dataChannel.writeDataChannelAck()
  143. if err != nil {
  144. return nil, err
  145. }
  146. err = dataChannel.commitReliabilityParams()
  147. if err != nil {
  148. return nil, err
  149. }
  150. return dataChannel, nil
  151. }
  152. // Read reads a packet of len(p) bytes as binary data
  153. func (c *DataChannel) Read(p []byte) (int, error) {
  154. n, _, err := c.ReadDataChannel(p)
  155. return n, err
  156. }
  157. // ReadDataChannel reads a packet of len(p) bytes
  158. func (c *DataChannel) ReadDataChannel(p []byte) (int, bool, error) {
  159. for {
  160. n, ppi, err := c.stream.ReadSCTP(p)
  161. if errors.Is(err, io.EOF) {
  162. // When the peer sees that an incoming stream was
  163. // reset, it also resets its corresponding outgoing stream.
  164. if closeErr := c.stream.Close(); closeErr != nil {
  165. return 0, false, closeErr
  166. }
  167. }
  168. if err != nil {
  169. return 0, false, err
  170. }
  171. if ppi == sctp.PayloadTypeWebRTCDCEP {
  172. if err = c.handleDCEP(p[:n]); err != nil {
  173. c.log.Errorf("Failed to handle DCEP: %s", err.Error())
  174. }
  175. continue
  176. } else if ppi == sctp.PayloadTypeWebRTCBinaryEmpty || ppi == sctp.PayloadTypeWebRTCStringEmpty {
  177. n = 0
  178. }
  179. atomic.AddUint32(&c.messagesReceived, 1)
  180. atomic.AddUint64(&c.bytesReceived, uint64(n))
  181. isString := ppi == sctp.PayloadTypeWebRTCString || ppi == sctp.PayloadTypeWebRTCStringEmpty
  182. return n, isString, err
  183. }
  184. }
  185. // SetReadDeadline sets a deadline for reads to return
  186. func (c *DataChannel) SetReadDeadline(t time.Time) error {
  187. return c.stream.SetReadDeadline(t)
  188. }
  189. // MessagesSent returns the number of messages sent
  190. func (c *DataChannel) MessagesSent() uint32 {
  191. return atomic.LoadUint32(&c.messagesSent)
  192. }
  193. // MessagesReceived returns the number of messages received
  194. func (c *DataChannel) MessagesReceived() uint32 {
  195. return atomic.LoadUint32(&c.messagesReceived)
  196. }
  197. // OnOpen sets an event handler which is invoked when
  198. // a DATA_CHANNEL_ACK message is received.
  199. // The handler is called only on thefor the channel opened
  200. // https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-data-protocol-09#section-5.2
  201. func (c *DataChannel) OnOpen(f func()) {
  202. c.mu.Lock()
  203. c.openCompleteHandlerOnce = sync.Once{}
  204. c.onOpenCompleteHandler = f
  205. c.mu.Unlock()
  206. }
  207. func (c *DataChannel) onOpenComplete() {
  208. c.mu.Lock()
  209. hdlr := c.onOpenCompleteHandler
  210. c.mu.Unlock()
  211. if hdlr != nil {
  212. go c.openCompleteHandlerOnce.Do(func() {
  213. hdlr()
  214. })
  215. }
  216. }
  217. // BytesSent returns the number of bytes sent
  218. func (c *DataChannel) BytesSent() uint64 {
  219. return atomic.LoadUint64(&c.bytesSent)
  220. }
  221. // BytesReceived returns the number of bytes received
  222. func (c *DataChannel) BytesReceived() uint64 {
  223. return atomic.LoadUint64(&c.bytesReceived)
  224. }
  225. // StreamIdentifier returns the Stream identifier associated to the stream.
  226. func (c *DataChannel) StreamIdentifier() uint16 {
  227. return c.stream.StreamIdentifier()
  228. }
  229. func (c *DataChannel) handleDCEP(data []byte) error {
  230. msg, err := parse(data)
  231. if err != nil {
  232. return fmt.Errorf("failed to parse DataChannel packet %w", err)
  233. }
  234. switch msg := msg.(type) {
  235. case *channelAck:
  236. c.log.Debug("Received DATA_CHANNEL_ACK")
  237. if err = c.commitReliabilityParams(); err != nil {
  238. return err
  239. }
  240. c.onOpenComplete()
  241. default:
  242. return fmt.Errorf("%w %v", ErrInvalidMessageType, msg)
  243. }
  244. return nil
  245. }
  246. // Write writes len(p) bytes from p as binary data
  247. func (c *DataChannel) Write(p []byte) (n int, err error) {
  248. return c.WriteDataChannel(p, false)
  249. }
  250. // WriteDataChannel writes len(p) bytes from p
  251. func (c *DataChannel) WriteDataChannel(p []byte, isString bool) (n int, err error) {
  252. // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6
  253. // SCTP does not support the sending of empty user messages. Therefore,
  254. // if an empty message has to be sent, the appropriate PPID (WebRTC
  255. // String Empty or WebRTC Binary Empty) is used and the SCTP user
  256. // message of one zero byte is sent. When receiving an SCTP user
  257. // message with one of these PPIDs, the receiver MUST ignore the SCTP
  258. // user message and process it as an empty message.
  259. var ppi sctp.PayloadProtocolIdentifier
  260. switch {
  261. case !isString && len(p) > 0:
  262. ppi = sctp.PayloadTypeWebRTCBinary
  263. case !isString && len(p) == 0:
  264. ppi = sctp.PayloadTypeWebRTCBinaryEmpty
  265. case isString && len(p) > 0:
  266. ppi = sctp.PayloadTypeWebRTCString
  267. case isString && len(p) == 0:
  268. ppi = sctp.PayloadTypeWebRTCStringEmpty
  269. }
  270. atomic.AddUint32(&c.messagesSent, 1)
  271. atomic.AddUint64(&c.bytesSent, uint64(len(p)))
  272. if len(p) == 0 {
  273. _, err := c.stream.WriteSCTP([]byte{0}, ppi)
  274. return 0, err
  275. }
  276. return c.stream.WriteSCTP(p, ppi)
  277. }
  278. func (c *DataChannel) writeDataChannelAck() error {
  279. ack := channelAck{}
  280. ackMsg, err := ack.Marshal()
  281. if err != nil {
  282. return fmt.Errorf("failed to marshal ChannelOpen ACK: %w", err)
  283. }
  284. if _, err = c.stream.WriteSCTP(ackMsg, sctp.PayloadTypeWebRTCDCEP); err != nil {
  285. return fmt.Errorf("failed to send ChannelOpen ACK: %w", err)
  286. }
  287. return err
  288. }
  289. // Close closes the DataChannel and the underlying SCTP stream.
  290. func (c *DataChannel) Close() error {
  291. // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
  292. // Closing of a data channel MUST be signaled by resetting the
  293. // corresponding outgoing streams [RFC6525]. This means that if one
  294. // side decides to close the data channel, it resets the corresponding
  295. // outgoing stream. When the peer sees that an incoming stream was
  296. // reset, it also resets its corresponding outgoing stream. Once this
  297. // is completed, the data channel is closed. Resetting a stream sets
  298. // the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with
  299. // a corresponding notification to the application layer that the reset
  300. // has been performed. Streams are available for reuse after a reset
  301. // has been performed.
  302. return c.stream.Close()
  303. }
  304. // BufferedAmount returns the number of bytes of data currently queued to be
  305. // sent over this stream.
  306. func (c *DataChannel) BufferedAmount() uint64 {
  307. return c.stream.BufferedAmount()
  308. }
  309. // BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
  310. // data that is considered "low." Defaults to 0.
  311. func (c *DataChannel) BufferedAmountLowThreshold() uint64 {
  312. return c.stream.BufferedAmountLowThreshold()
  313. }
  314. // SetBufferedAmountLowThreshold is used to update the threshold.
  315. // See BufferedAmountLowThreshold().
  316. func (c *DataChannel) SetBufferedAmountLowThreshold(th uint64) {
  317. c.stream.SetBufferedAmountLowThreshold(th)
  318. }
  319. // OnBufferedAmountLow sets the callback handler which would be called when the
  320. // number of bytes of outgoing data buffered is lower than the threshold.
  321. func (c *DataChannel) OnBufferedAmountLow(f func()) {
  322. c.stream.OnBufferedAmountLow(f)
  323. }
  324. func (c *DataChannel) commitReliabilityParams() error {
  325. switch c.Config.ChannelType {
  326. case ChannelTypeReliable:
  327. c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
  328. case ChannelTypeReliableUnordered:
  329. c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
  330. case ChannelTypePartialReliableRexmit:
  331. c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
  332. case ChannelTypePartialReliableRexmitUnordered:
  333. c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
  334. case ChannelTypePartialReliableTimed:
  335. c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
  336. case ChannelTypePartialReliableTimedUnordered:
  337. c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
  338. default:
  339. return fmt.Errorf("%w %v", ErrInvalidChannelType, c.Config.ChannelType)
  340. }
  341. return nil
  342. }