ts-demuxer.go 5.5 KB


  1. package mpeg2
  2. import (
  3. "errors"
  4. "io"
  5. "github.com/yapingcat/gomedia/codec"
  6. )
  // pakcet_t is the per-stream assembly buffer: the payload bytes collected so
  // far plus the timestamps copied from the PES packet they arrived with.
  type pakcet_t struct {
  	// payload holds the elementary-stream bytes accumulated so far.
  	payload []byte
  	// pts and dts are stored as read from the PES packet; they are divided by
  	// 90 only when handed to the OnFrame callback.
  	pts, dts uint64
  }
  12. func newPacket_t(size uint32) *pakcet_t {
  13. return &pakcet_t{
  14. payload: make([]byte, 0, size),
  15. pts: 0,
  16. dts: 0,
  17. }
  18. }
  19. type tsstream struct {
  20. cid TS_STREAM_TYPE
  21. pes_sid PES_STREMA_ID
  22. pes_pkg *PesPacket
  23. pkg *pakcet_t
  24. }
  25. type tsprogram struct {
  26. pn uint16
  27. streams map[uint16]*tsstream
  28. }
  29. type TSDemuxer struct {
  30. programs map[uint16]*tsprogram
  31. OnFrame func(cid TS_STREAM_TYPE, frame []byte, pts uint64, dts uint64)
  32. OnTSPacket func(pkg *TSPacket)
  33. }
  34. func NewTSDemuxer() *TSDemuxer {
  35. return &TSDemuxer{
  36. programs: make(map[uint16]*tsprogram),
  37. OnFrame: nil,
  38. OnTSPacket: nil,
  39. }
  40. }
  41. func (demuxer *TSDemuxer) Input(r io.Reader) error {
  42. buf := make([]byte, TS_PAKCET_SIZE)
  43. _, err := io.ReadFull(r, buf)
  44. if err != nil {
  45. return errNeedMore
  46. }
  47. for {
  48. bs := codec.NewBitStream(buf)
  49. var pkg TSPacket
  50. if err := pkg.DecodeHeader(bs); err != nil {
  51. return err
  52. }
  53. if pkg.PID == uint16(TS_PID_PAT) {
  54. if pkg.Payload_unit_start_indicator == 1 {
  55. bs.SkipBits(8)
  56. }
  57. pat := NewPat()
  58. if err := pat.Decode(bs); err != nil {
  59. return err
  60. }
  61. pkg.Payload = pat
  62. if pat.Table_id != uint8(TS_TID_PAS) {
  63. return errors.New("pat table id is wrong")
  64. }
  65. for _, pmt := range pat.Pmts {
  66. if pmt.Program_number != 0x0000 {
  67. if _, found := demuxer.programs[pmt.PID]; !found {
  68. demuxer.programs[pmt.PID] = &tsprogram{pn: 0, streams: make(map[uint16]*tsstream)}
  69. }
  70. }
  71. }
  72. } else {
  73. for p, s := range demuxer.programs {
  74. if p == pkg.PID { // pmt table
  75. if pkg.Payload_unit_start_indicator == 1 {
  76. bs.SkipBits(8) //pointer filed
  77. }
  78. pmt := NewPmt()
  79. if err := pmt.Decode(bs); err != nil {
  80. return err
  81. }
  82. pkg.Payload = pmt
  83. s.pn = pmt.Program_number
  84. for _, ps := range pmt.Streams {
  85. if _, found := s.streams[ps.Elementary_PID]; !found {
  86. s.streams[ps.Elementary_PID] = &tsstream{
  87. cid: TS_STREAM_TYPE(ps.StreamType),
  88. pes_sid: findPESIDByStreamType(TS_STREAM_TYPE(ps.StreamType)),
  89. pes_pkg: NewPesPacket(),
  90. }
  91. }
  92. }
  93. } else {
  94. for sid, stream := range s.streams {
  95. if sid != pkg.PID {
  96. continue
  97. }
  98. if pkg.Payload_unit_start_indicator == 1 {
  99. err := stream.pes_pkg.Decode(bs)
  100. // ignore error if it was a short payload read, next ts packet should append missing data
  101. if err != nil && !(errors.Is(err, errNeedMore) && stream.pes_pkg.Pes_payload != nil) {
  102. return err
  103. }
  104. pkg.Payload = stream.pes_pkg
  105. } else {
  106. stream.pes_pkg.Pes_payload = bs.RemainData()
  107. pkg.Payload = bs.RemainData()
  108. }
  109. stype := findPESIDByStreamType(stream.cid)
  110. if stype == PES_STREAM_AUDIO {
  111. demuxer.doAudioPesPacket(stream, pkg.Payload_unit_start_indicator)
  112. } else if stype == PES_STREAM_VIDEO {
  113. demuxer.doVideoPesPacket(stream, pkg.Payload_unit_start_indicator)
  114. }
  115. }
  116. }
  117. }
  118. }
  119. if demuxer.OnTSPacket != nil {
  120. demuxer.OnTSPacket(&pkg)
  121. }
  122. _, err := io.ReadFull(r, buf)
  123. if err != nil {
  124. if errors.Is(err, io.EOF) {
  125. break
  126. } else {
  127. return errNeedMore
  128. }
  129. }
  130. }
  131. demuxer.flush()
  132. return nil
  133. }
  134. func (demuxer *TSDemuxer) flush() {
  135. for _, pm := range demuxer.programs {
  136. for _, stream := range pm.streams {
  137. if stream.pkg == nil || len(stream.pkg.payload) == 0 {
  138. continue
  139. }
  140. if demuxer.OnFrame != nil {
  141. demuxer.OnFrame(stream.cid, stream.pkg.payload, stream.pkg.pts/90, stream.pkg.dts/90)
  142. }
  143. }
  144. }
  145. }
  146. func (demuxer *TSDemuxer) doVideoPesPacket(stream *tsstream, start uint8) {
  147. if stream.cid != TS_STREAM_H264 && stream.cid != TS_STREAM_H265 {
  148. return
  149. }
  150. if stream.pkg == nil {
  151. stream.pkg = newPacket_t(1024)
  152. stream.pkg.pts = stream.pes_pkg.Pts
  153. stream.pkg.dts = stream.pes_pkg.Dts
  154. }
  155. stream.pkg.payload = append(stream.pkg.payload, stream.pes_pkg.Pes_payload...)
  156. demuxer.splitH26XFrame(stream)
  157. stream.pkg.pts = stream.pes_pkg.Pts
  158. stream.pkg.dts = stream.pes_pkg.Dts
  159. }
  160. func (demuxer *TSDemuxer) doAudioPesPacket(stream *tsstream, start uint8) {
  161. if stream.cid != TS_STREAM_AAC {
  162. return
  163. }
  164. if stream.pkg == nil {
  165. stream.pkg = newPacket_t(1024)
  166. stream.pkg.pts = stream.pes_pkg.Pts
  167. stream.pkg.dts = stream.pes_pkg.Dts
  168. }
  169. if len(stream.pkg.payload) > 0 && (start == 1 || stream.pes_pkg.Pts != stream.pkg.pts) {
  170. if demuxer.OnFrame != nil {
  171. demuxer.OnFrame(stream.cid, stream.pkg.payload, stream.pkg.pts/90, stream.pkg.dts/90)
  172. }
  173. stream.pkg.payload = stream.pkg.payload[:0]
  174. }
  175. stream.pkg.payload = append(stream.pkg.payload, stream.pes_pkg.Pes_payload...)
  176. stream.pkg.pts = stream.pes_pkg.Pts
  177. stream.pkg.dts = stream.pes_pkg.Dts
  178. }
  179. func (demuxer *TSDemuxer) splitH26XFrame(stream *tsstream) {
  180. data := stream.pkg.payload
  181. start, _ := codec.FindStartCode(data, 0)
  182. datalen := len(data)
  183. for start < datalen {
  184. end, _ := codec.FindStartCode(data, start+3)
  185. if end < 0 {
  186. break
  187. }
  188. if (stream.cid == TS_STREAM_H264 && codec.H264NaluTypeWithoutStartCode(data[start:end]) == codec.H264_NAL_AUD) ||
  189. (stream.cid == TS_STREAM_H265 && codec.H265NaluTypeWithoutStartCode(data[start:end]) == codec.H265_NAL_AUD) {
  190. start = end
  191. continue
  192. }
  193. if demuxer.OnFrame != nil {
  194. demuxer.OnFrame(stream.cid, data[start:end], stream.pkg.pts/90, stream.pkg.dts/90)
  195. }
  196. start = end
  197. }
  198. if start == 0 {
  199. return
  200. }
  201. copy(stream.pkg.payload, data[start:datalen])
  202. stream.pkg.payload = stream.pkg.payload[0 : datalen-start]
  203. }