reassembly_queue.go

package sctp

import (
	"errors"
	"io"
	"sort"
	"sync/atomic"
)

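// sortChunksByTSN sorts payload-data chunks in ascending TSN order, using
// serial number arithmetic (sna32LT) so that TSN wrap-around is handled.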
func sortChunksByTSN(a []*chunkPayloadData) {
	sort.Slice(a, func(i, j int) bool {
		return sna32LT(a[i].tsn, a[j].tsn)
	})
}

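// sortChunksBySSN sorts chunk sets in ascending SSN order, using serial
// number arithmetic (sna16LT) so that SSN wrap-around is handled.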
func sortChunksBySSN(a []*chunkSet) {
	sort.Slice(a, func(i, j int) bool {
		return sna16LT(a[i].ssn, a[j].ssn)
	})
}

// chunkSet is a set of chunks that share the same SSN
type chunkSet struct {
	ssn    uint16 // used only with the ordered chunks
	ppi    PayloadProtocolIdentifier
	chunks []*chunkPayloadData
}

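// newChunkSet creates an empty chunkSet for the given stream sequence number
// and payload protocol identifier.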
func newChunkSet(ssn uint16, ppi PayloadProtocolIdentifier) *chunkSet {
	return &chunkSet{
		ssn:    ssn,
		ppi:    ppi,
		chunks: []*chunkPayloadData{},
	}
}

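// push adds a fragment to the set unless its TSN is already present, keeps the
// fragments sorted by TSN, and reports whether the set is now complete.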
func (set *chunkSet) push(chunk *chunkPayloadData) bool {
	// check if dup
	for _, c := range set.chunks {
		if c.tsn == chunk.tsn {
			return false
		}
	}

	// append and sort
	set.chunks = append(set.chunks, chunk)
	sortChunksByTSN(set.chunks)

	// Check if we now have a complete set
	complete := set.isComplete()

	return complete
}

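// isComplete reports whether the set holds every fragment of a single user
// message (see the conditions listed below).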
func (set *chunkSet) isComplete() bool {
	// Conditions for a complete set:
	// 0. Has at least one chunk.
	// 1. Begins with beginningFragment set to true
	// 2. Ends with endingFragment set to true
	// 3. TSNs monotonically increase by 1 from beginning to end

	// 0.
	nChunks := len(set.chunks)
	if nChunks == 0 {
		return false
	}

	// 1.
	if !set.chunks[0].beginningFragment {
		return false
	}

	// 2.
	if !set.chunks[nChunks-1].endingFragment {
		return false
	}

	// 3.
	var lastTSN uint32
	for i, c := range set.chunks {
		if i > 0 {
			// Fragments must have contiguous TSN
			// From RFC 4960 Section 3.3.1:
			//   When a user message is fragmented into multiple chunks, the TSNs are
			//   used by the receiver to reassemble the message.  This means that the
			//   TSNs for each fragment of a fragmented user message MUST be strictly
			//   sequential.
			if c.tsn != lastTSN+1 {
				// mid or end fragment is missing
				return false
			}
		}

		lastTSN = c.tsn
	}

	return true
}

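// reassemblyQueue buffers DATA chunks received on a single stream (si) and
// reassembles them into complete user messages. Ordered fragments are grouped
// into chunkSets keyed by SSN; unordered fragments are kept in unorderedChunks
// (sorted by TSN) until a complete run is found. nBytes tracks the total
// buffered payload size.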
type reassemblyQueue struct {
	si              uint16
	nextSSN         uint16 // expected SSN for next ordered chunk
	ordered         []*chunkSet
	unordered       []*chunkSet
	unorderedChunks []*chunkPayloadData
	nBytes          uint64
}

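// errTryAgain is returned by read when no complete message is ready to be
// delivered yet.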
var errTryAgain = errors.New("try again")

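// newReassemblyQueue creates a reassembly queue for stream identifier si.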
func newReassemblyQueue(si uint16) *reassemblyQueue {
	// From RFC 4960 Sec 6.5:
	//   The Stream Sequence Number in all the streams MUST start from 0 when
	//   the association is established.  Also, when the Stream Sequence
	//   Number reaches the value 65535 the next Stream Sequence Number MUST
	//   be set to 0.
	return &reassemblyQueue{
		si:        si,
		nextSSN:   0, // see the note from RFC 4960 Sec 6.5 above
		ordered:   make([]*chunkSet, 0),
		unordered: make([]*chunkSet, 0),
	}
}

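// push files a DATA chunk into the queue and reports whether it completed a
// chunk set (a fully reassembled message). Chunks for a different stream, and
// ordered chunks whose SSN has already been passed, are dropped.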
func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool {
	var cset *chunkSet

	if chunk.streamIdentifier != r.si {
		return false
	}

	if chunk.unordered {
		// First, insert into unorderedChunks array
		r.unorderedChunks = append(r.unorderedChunks, chunk)
		atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData)))
		sortChunksByTSN(r.unorderedChunks)

		// Scan unorderedChunks that are contiguous (in TSN)
		cset = r.findCompleteUnorderedChunkSet()

		// If found, append the complete set to the unordered array
		if cset != nil {
			r.unordered = append(r.unordered, cset)

			return true
		}

		return false
	}

	// This is an ordered chunk
	if sna16LT(chunk.streamSequenceNumber, r.nextSSN) {
		return false
	}

	// Check if a chunkSet with the SSN already exists
	for _, set := range r.ordered {
		if set.ssn == chunk.streamSequenceNumber {
			cset = set

			break
		}
	}

	// If not found, create a new chunkSet
	if cset == nil {
		cset = newChunkSet(chunk.streamSequenceNumber, chunk.payloadType)
		r.ordered = append(r.ordered, cset)
		if !chunk.unordered {
			sortChunksBySSN(r.ordered)
		}
	}

	atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData)))

	return cset.push(chunk)
}

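// findCompleteUnorderedChunkSet scans the TSN-sorted unorderedChunks for a
// contiguous run that starts with a beginning fragment and ends with an ending
// fragment. If one is found, the run is removed from unorderedChunks and
// returned as a new chunkSet; otherwise nil is returned.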
func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet {
	startIdx := -1
	nChunks := 0
	var lastTSN uint32
	var found bool

	for i, c := range r.unorderedChunks {
		// seek a beginning fragment
		if c.beginningFragment {
			startIdx = i
			nChunks = 1
			lastTSN = c.tsn

			if c.endingFragment {
				found = true

				break
			}

			continue
		}

		if startIdx < 0 {
			continue
		}

		// Check if contiguous in TSN
		if c.tsn != lastTSN+1 {
			startIdx = -1

			continue
		}

		lastTSN = c.tsn
		nChunks++

		if c.endingFragment {
			found = true

			break
		}
	}

	if !found {
		return nil
	}

	// Extract the range of chunks
	var chunks []*chunkPayloadData
	chunks = append(chunks, r.unorderedChunks[startIdx:startIdx+nChunks]...)

	r.unorderedChunks = append(
		r.unorderedChunks[:startIdx],
		r.unorderedChunks[startIdx+nChunks:]...)

	chunkSet := newChunkSet(0, chunks[0].payloadType)
	chunkSet.chunks = chunks

	return chunkSet
}

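// isReadable reports whether read would deliver a message: either a complete
// unordered set is queued, or the first ordered set is complete and its SSN is
// not ahead of nextSSN.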
func (r *reassemblyQueue) isReadable() bool {
	// Check unordered first
	if len(r.unordered) > 0 {
		// The chunk sets in r.unordered should all be complete.
		return true
	}

	// Check ordered sets
	if len(r.ordered) > 0 {
		cset := r.ordered[0]
		if cset.isComplete() {
			if sna16LTE(cset.ssn, r.nextSSN) {
				return true
			}
		}
	}

	return false
}

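// read reassembles the next complete message into buf, returning the number of
// bytes written and the message's payload protocol identifier. Unordered
// messages are delivered before ordered ones. It returns errTryAgain when no
// message is deliverable yet, and io.ErrShortBuffer when buf cannot hold the
// whole message.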
func (r *reassemblyQueue) read(buf []byte) (int, PayloadProtocolIdentifier, error) {
	var cset *chunkSet
	// Check unordered first
	switch {
	case len(r.unordered) > 0:
		cset = r.unordered[0]
		r.unordered = r.unordered[1:]
	case len(r.ordered) > 0:
		// Now, check ordered
		cset = r.ordered[0]

		if !cset.isComplete() {
			return 0, 0, errTryAgain
		}

		if sna16GT(cset.ssn, r.nextSSN) {
			return 0, 0, errTryAgain
		}

		r.ordered = r.ordered[1:]
		if cset.ssn == r.nextSSN {
			r.nextSSN++
		}
	default:
		return 0, 0, errTryAgain
	}

	// Concat all fragments into the buffer
	nWritten := 0
	ppi := cset.ppi
	var err error
	for _, c := range cset.chunks {
		toCopy := len(c.userData)
		r.subtractNumBytes(toCopy)

		if err == nil {
			n := copy(buf[nWritten:], c.userData)
			nWritten += n
			if n < toCopy {
				err = io.ErrShortBuffer
			}
		}
	}

	return nWritten, ppi, err
}

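// forwardTSNForOrdered processes a FORWARD TSN for ordered delivery: chunk
// sets with an SSN at or before lastSSN that are still incomplete are dropped,
// and nextSSN is advanced past lastSSN if it has not already moved beyond it.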
func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) {
	// Use lastSSN to locate chunkSets, then drop any located set that is
	// not yet complete.
	keep := []*chunkSet{}
	for _, set := range r.ordered {
		if sna16LTE(set.ssn, lastSSN) {
			if !set.isComplete() {
				// drop the set
				for _, c := range set.chunks {
					r.subtractNumBytes(len(c.userData))
				}

				continue
			}
		}

		keep = append(keep, set)
	}
	r.ordered = keep

	// Finally, forward nextSSN
	if sna16LTE(r.nextSSN, lastSSN) {
		r.nextSSN = lastSSN + 1
	}
}

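// forwardTSNForUnordered processes a FORWARD TSN for unordered delivery:
// buffered unordered fragments with a TSN at or before newCumulativeTSN are
// dropped. Complete sets already moved to r.unordered are left untouched.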
func (r *reassemblyQueue) forwardTSNForUnordered(newCumulativeTSN uint32) {
	// Remove all fragments in the unordered sets that contain chunks
	// equal to or older than `newCumulativeTSN`.
	// We know all sets in r.unordered are complete ones.
	// Just remove chunks that are equal to or older than newCumulativeTSN
	// from the unorderedChunks.
	lastIdx := -1
	for i, c := range r.unorderedChunks {
		if sna32GT(c.tsn, newCumulativeTSN) {
			break
		}
		lastIdx = i
	}

	if lastIdx >= 0 {
		for _, c := range r.unorderedChunks[0 : lastIdx+1] {
			r.subtractNumBytes(len(c.userData))
		}
		r.unorderedChunks = r.unorderedChunks[lastIdx+1:]
	}
}

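// subtractNumBytes decreases the buffered-byte counter, clamping it at zero.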
func (r *reassemblyQueue) subtractNumBytes(nBytes int) {
	cur := atomic.LoadUint64(&r.nBytes)
	if int(cur) >= nBytes {
		atomic.AddUint64(&r.nBytes, -uint64(nBytes))
	} else {
		atomic.StoreUint64(&r.nBytes, 0)
	}
}

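// getNumBytes returns the number of payload bytes currently buffered.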
func (r *reassemblyQueue) getNumBytes() int {
	return int(atomic.LoadUint64(&r.nBytes))
}