stream.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. package sctp
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "math"
  7. "os"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "github.com/pion/logging"
  12. )
  const (
	// ReliabilityTypeReliable is used for reliable transmission.
	ReliabilityTypeReliable byte = iota
	// ReliabilityTypeRexmit is used for partial reliability by retransmission count.
	ReliabilityTypeRexmit
	// ReliabilityTypeTimed is used for partial reliability by retransmission duration.
	ReliabilityTypeTimed
)

// StreamState is an enum for the SCTP stream state field; it identifies the
// lifecycle stage a stream is in.
type StreamState int

// StreamState enums.
const (
	// StreamStateOpen is the state every Stream object starts in.
	StreamStateOpen StreamState = iota
	// StreamStateClosing means the outgoing stream is being reset.
	StreamStateClosing
	// StreamStateClosed means the stream has been closed.
	StreamStateClosed
)

// String implements fmt.Stringer. Values outside the defined StreamState
// constants render as "unknown".
func (ss StreamState) String() string {
	var name string
	switch ss {
	case StreamStateOpen:
		name = "open"
	case StreamStateClosing:
		name = "closing"
	case StreamStateClosed:
		name = "closed"
	default:
		name = "unknown"
	}
	return name
}

// SCTP stream errors.
var (
	// ErrOutboundPacketTooLarge reports a write larger than the association's
	// maximum message size.
	ErrOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size")
	// ErrStreamClosed reports a write on a stream that can no longer send.
	ErrStreamClosed = errors.New("stream closed")
	// ErrReadDeadlineExceeded is returned by reads once the read deadline has
	// passed. It wraps os.ErrDeadlineExceeded, so errors.Is matches either.
	ErrReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded)
)
  47. // Stream represents an SCTP stream
  48. type Stream struct {
  49. association *Association
  50. lock sync.RWMutex
  51. streamIdentifier uint16
  52. defaultPayloadType PayloadProtocolIdentifier
  53. reassemblyQueue *reassemblyQueue
  54. sequenceNumber uint16
  55. readNotifier *sync.Cond
  56. readErr error
  57. readTimeoutCancel chan struct{}
  58. unordered bool
  59. reliabilityType byte
  60. reliabilityValue uint32
  61. bufferedAmount uint64
  62. bufferedAmountLow uint64
  63. onBufferedAmountLow func()
  64. state StreamState
  65. log logging.LeveledLogger
  66. name string
  67. }
  68. // StreamIdentifier returns the Stream identifier associated to the stream.
  69. func (s *Stream) StreamIdentifier() uint16 {
  70. s.lock.RLock()
  71. defer s.lock.RUnlock()
  72. return s.streamIdentifier
  73. }
  74. // SetDefaultPayloadType sets the default payload type used by Write.
  75. func (s *Stream) SetDefaultPayloadType(defaultPayloadType PayloadProtocolIdentifier) {
  76. atomic.StoreUint32((*uint32)(&s.defaultPayloadType), uint32(defaultPayloadType))
  77. }
  78. // SetReliabilityParams sets reliability parameters for this stream.
  79. func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint32) {
  80. s.lock.Lock()
  81. defer s.lock.Unlock()
  82. s.setReliabilityParams(unordered, relType, relVal)
  83. }
  84. // setReliabilityParams sets reliability parameters for this stream.
  85. // The caller should hold the lock.
  86. func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) {
  87. s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d",
  88. s.name, !unordered, relType, relVal)
  89. s.unordered = unordered
  90. s.reliabilityType = relType
  91. s.reliabilityValue = relVal
  92. }
  93. // Read reads a packet of len(p) bytes, dropping the Payload Protocol Identifier.
  94. // Returns EOF when the stream is reset or an error if the stream is closed
  95. // otherwise.
  96. func (s *Stream) Read(p []byte) (int, error) {
  97. n, _, err := s.ReadSCTP(p)
  98. return n, err
  99. }
  100. // ReadSCTP reads a packet of len(p) bytes and returns the associated Payload
  101. // Protocol Identifier.
  102. // Returns EOF when the stream is reset or an error if the stream is closed
  103. // otherwise.
  104. func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) {
  105. s.lock.Lock()
  106. defer s.lock.Unlock()
  107. defer func() {
  108. // close readTimeoutCancel if the current read timeout routine is no longer effective
  109. if s.readTimeoutCancel != nil && s.readErr != nil {
  110. close(s.readTimeoutCancel)
  111. s.readTimeoutCancel = nil
  112. }
  113. }()
  114. for {
  115. n, ppi, err := s.reassemblyQueue.read(p)
  116. if err == nil {
  117. return n, ppi, nil
  118. } else if errors.Is(err, io.ErrShortBuffer) {
  119. return 0, PayloadProtocolIdentifier(0), err
  120. }
  121. err = s.readErr
  122. if err != nil {
  123. return 0, PayloadProtocolIdentifier(0), err
  124. }
  125. s.readNotifier.Wait()
  126. }
  127. }
  128. // SetReadDeadline sets the read deadline in an identical way to net.Conn
  129. func (s *Stream) SetReadDeadline(deadline time.Time) error {
  130. s.lock.Lock()
  131. defer s.lock.Unlock()
  132. if s.readTimeoutCancel != nil {
  133. close(s.readTimeoutCancel)
  134. s.readTimeoutCancel = nil
  135. }
  136. if s.readErr != nil {
  137. if !errors.Is(s.readErr, ErrReadDeadlineExceeded) {
  138. return nil
  139. }
  140. s.readErr = nil
  141. }
  142. if !deadline.IsZero() {
  143. s.readTimeoutCancel = make(chan struct{})
  144. go func(readTimeoutCancel chan struct{}) {
  145. t := time.NewTimer(time.Until(deadline))
  146. select {
  147. case <-readTimeoutCancel:
  148. t.Stop()
  149. return
  150. case <-t.C:
  151. s.lock.Lock()
  152. if s.readErr == nil {
  153. s.readErr = ErrReadDeadlineExceeded
  154. }
  155. s.readTimeoutCancel = nil
  156. s.lock.Unlock()
  157. s.readNotifier.Signal()
  158. }
  159. }(s.readTimeoutCancel)
  160. }
  161. return nil
  162. }
  163. func (s *Stream) handleData(pd *chunkPayloadData) {
  164. s.lock.Lock()
  165. defer s.lock.Unlock()
  166. var readable bool
  167. if s.reassemblyQueue.push(pd) {
  168. readable = s.reassemblyQueue.isReadable()
  169. s.log.Debugf("[%s] reassemblyQueue readable=%v", s.name, readable)
  170. if readable {
  171. s.log.Debugf("[%s] readNotifier.signal()", s.name)
  172. s.readNotifier.Signal()
  173. s.log.Debugf("[%s] readNotifier.signal() done", s.name)
  174. }
  175. }
  176. }
  177. func (s *Stream) handleForwardTSNForOrdered(ssn uint16) {
  178. var readable bool
  179. func() {
  180. s.lock.Lock()
  181. defer s.lock.Unlock()
  182. if s.unordered {
  183. return // unordered chunks are handled by handleForwardUnordered method
  184. }
  185. // Remove all chunks older than or equal to the new TSN from
  186. // the reassemblyQueue.
  187. s.reassemblyQueue.forwardTSNForOrdered(ssn)
  188. readable = s.reassemblyQueue.isReadable()
  189. }()
  190. // Notify the reader asynchronously if there's a data chunk to read.
  191. if readable {
  192. s.readNotifier.Signal()
  193. }
  194. }
  195. func (s *Stream) handleForwardTSNForUnordered(newCumulativeTSN uint32) {
  196. var readable bool
  197. func() {
  198. s.lock.Lock()
  199. defer s.lock.Unlock()
  200. if !s.unordered {
  201. return // ordered chunks are handled by handleForwardTSNOrdered method
  202. }
  203. // Remove all chunks older than or equal to the new TSN from
  204. // the reassemblyQueue.
  205. s.reassemblyQueue.forwardTSNForUnordered(newCumulativeTSN)
  206. readable = s.reassemblyQueue.isReadable()
  207. }()
  208. // Notify the reader asynchronously if there's a data chunk to read.
  209. if readable {
  210. s.readNotifier.Signal()
  211. }
  212. }
  213. // Write writes len(p) bytes from p with the default Payload Protocol Identifier
  214. func (s *Stream) Write(p []byte) (n int, err error) {
  215. ppi := PayloadProtocolIdentifier(atomic.LoadUint32((*uint32)(&s.defaultPayloadType)))
  216. return s.WriteSCTP(p, ppi)
  217. }
  218. // WriteSCTP writes len(p) bytes from p to the DTLS connection
  219. func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error) {
  220. maxMessageSize := s.association.MaxMessageSize()
  221. if len(p) > int(maxMessageSize) {
  222. return 0, fmt.Errorf("%w: %v", ErrOutboundPacketTooLarge, math.MaxUint16)
  223. }
  224. if s.State() != StreamStateOpen {
  225. return 0, ErrStreamClosed
  226. }
  227. chunks := s.packetize(p, ppi)
  228. n := len(p)
  229. err := s.association.sendPayloadData(chunks)
  230. if err != nil {
  231. return n, ErrStreamClosed
  232. }
  233. return n, nil
  234. }
  235. func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPayloadData {
  236. s.lock.Lock()
  237. defer s.lock.Unlock()
  238. i := uint32(0)
  239. remaining := uint32(len(raw))
  240. // From draft-ietf-rtcweb-data-protocol-09, section 6:
  241. // All Data Channel Establishment Protocol messages MUST be sent using
  242. // ordered delivery and reliable transmission.
  243. unordered := ppi != PayloadTypeWebRTCDCEP && s.unordered
  244. var chunks []*chunkPayloadData
  245. var head *chunkPayloadData
  246. for remaining != 0 {
  247. fragmentSize := min32(s.association.maxPayloadSize, remaining)
  248. // Copy the userdata since we'll have to store it until acked
  249. // and the caller may re-use the buffer in the mean time
  250. userData := make([]byte, fragmentSize)
  251. copy(userData, raw[i:i+fragmentSize])
  252. chunk := &chunkPayloadData{
  253. streamIdentifier: s.streamIdentifier,
  254. userData: userData,
  255. unordered: unordered,
  256. beginningFragment: i == 0,
  257. endingFragment: remaining-fragmentSize == 0,
  258. immediateSack: false,
  259. payloadType: ppi,
  260. streamSequenceNumber: s.sequenceNumber,
  261. head: head,
  262. }
  263. if head == nil {
  264. head = chunk
  265. }
  266. chunks = append(chunks, chunk)
  267. remaining -= fragmentSize
  268. i += fragmentSize
  269. }
  270. // RFC 4960 Sec 6.6
  271. // Note: When transmitting ordered and unordered data, an endpoint does
  272. // not increment its Stream Sequence Number when transmitting a DATA
  273. // chunk with U flag set to 1.
  274. if !unordered {
  275. s.sequenceNumber++
  276. }
  277. s.bufferedAmount += uint64(len(raw))
  278. s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)
  279. return chunks
  280. }
  281. // Close closes the write-direction of the stream.
  282. // Future calls to Write are not permitted after calling Close.
  283. func (s *Stream) Close() error {
  284. if sid, resetOutbound := func() (uint16, bool) {
  285. s.lock.Lock()
  286. defer s.lock.Unlock()
  287. s.log.Debugf("[%s] Close: state=%s", s.name, s.state.String())
  288. if s.state == StreamStateOpen {
  289. if s.readErr == nil {
  290. s.state = StreamStateClosing
  291. } else {
  292. s.state = StreamStateClosed
  293. }
  294. s.log.Debugf("[%s] state change: open => %s", s.name, s.state.String())
  295. return s.streamIdentifier, true
  296. }
  297. return s.streamIdentifier, false
  298. }(); resetOutbound {
  299. // Reset the outgoing stream
  300. // https://tools.ietf.org/html/rfc6525
  301. return s.association.sendResetRequest(sid)
  302. }
  303. return nil
  304. }
  305. // BufferedAmount returns the number of bytes of data currently queued to be sent over this stream.
  306. func (s *Stream) BufferedAmount() uint64 {
  307. s.lock.RLock()
  308. defer s.lock.RUnlock()
  309. return s.bufferedAmount
  310. }
  311. // BufferedAmountLowThreshold returns the number of bytes of buffered outgoing data that is
  312. // considered "low." Defaults to 0.
  313. func (s *Stream) BufferedAmountLowThreshold() uint64 {
  314. s.lock.RLock()
  315. defer s.lock.RUnlock()
  316. return s.bufferedAmountLow
  317. }
  318. // SetBufferedAmountLowThreshold is used to update the threshold.
  319. // See BufferedAmountLowThreshold().
  320. func (s *Stream) SetBufferedAmountLowThreshold(th uint64) {
  321. s.lock.Lock()
  322. defer s.lock.Unlock()
  323. s.bufferedAmountLow = th
  324. }
  325. // OnBufferedAmountLow sets the callback handler which would be called when the number of
  326. // bytes of outgoing data buffered is lower than the threshold.
  327. func (s *Stream) OnBufferedAmountLow(f func()) {
  328. s.lock.Lock()
  329. defer s.lock.Unlock()
  330. s.onBufferedAmountLow = f
  331. }
  332. // This method is called by association's readLoop (go-)routine to notify this stream
  333. // of the specified amount of outgoing data has been delivered to the peer.
  334. func (s *Stream) onBufferReleased(nBytesReleased int) {
  335. if nBytesReleased <= 0 {
  336. return
  337. }
  338. s.lock.Lock()
  339. fromAmount := s.bufferedAmount
  340. if s.bufferedAmount < uint64(nBytesReleased) {
  341. s.bufferedAmount = 0
  342. s.log.Errorf("[%s] released buffer size %d should be <= %d",
  343. s.name, nBytesReleased, s.bufferedAmount)
  344. } else {
  345. s.bufferedAmount -= uint64(nBytesReleased)
  346. }
  347. s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)
  348. if s.onBufferedAmountLow != nil && fromAmount > s.bufferedAmountLow && s.bufferedAmount <= s.bufferedAmountLow {
  349. f := s.onBufferedAmountLow
  350. s.lock.Unlock()
  351. f()
  352. return
  353. }
  354. s.lock.Unlock()
  355. }
  356. func (s *Stream) getNumBytesInReassemblyQueue() int {
  357. // No lock is required as it reads the size with atomic load function.
  358. return s.reassemblyQueue.getNumBytes()
  359. }
  360. func (s *Stream) onInboundStreamReset() {
  361. s.lock.Lock()
  362. defer s.lock.Unlock()
  363. s.log.Debugf("[%s] onInboundStreamReset: state=%s", s.name, s.state.String())
  364. // No more inbound data to read. Unblock the read with io.EOF.
  365. // This should cause DCEP layer (datachannel package) to call Close() which
  366. // will reset outgoing stream also.
  367. // See RFC 8831 section 6.7:
  368. // if one side decides to close the data channel, it resets the corresponding
  369. // outgoing stream. When the peer sees that an incoming stream was
  370. // reset, it also resets its corresponding outgoing stream. Once this
  371. // is completed, the data channel is closed.
  372. s.readErr = io.EOF
  373. s.readNotifier.Broadcast()
  374. if s.state == StreamStateClosing {
  375. s.log.Debugf("[%s] state change: closing => closed", s.name)
  376. s.state = StreamStateClosed
  377. }
  378. }
  379. // State return the stream state.
  380. func (s *Stream) State() StreamState {
  381. s.lock.RLock()
  382. defer s.lock.RUnlock()
  383. return s.state
  384. }