utils.go 7.3 KB


  1. // Copyright (c) 2024 Winlin
  2. //
  3. // SPDX-License-Identifier: MIT
  4. package main
  5. import (
  6. "context"
  7. "encoding/binary"
  8. "encoding/json"
  9. stdErr "errors"
  10. "fmt"
  11. "io"
  12. "io/ioutil"
  13. "net"
  14. "net/http"
  15. "net/url"
  16. "os"
  17. "path"
  18. "reflect"
  19. "regexp"
  20. "strconv"
  21. "strings"
  22. "syscall"
  23. "time"
  24. "srs-proxy/errors"
  25. "srs-proxy/logger"
  26. )
  27. func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) {
  28. w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version()))
  29. b, err := json.Marshal(data)
  30. if err != nil {
  31. apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data))
  32. return
  33. }
  34. w.Header().Set("Content-Type", "application/json")
  35. w.WriteHeader(http.StatusOK)
  36. w.Write(b)
  37. }
  38. func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) {
  39. logger.Wf(ctx, "HTTP API error %+v", err)
  40. w.Header().Set("Content-Type", "text/plain; charset=utf-8")
  41. w.WriteHeader(http.StatusInternalServerError)
  42. fmt.Fprintln(w, fmt.Sprintf("%v", err))
  43. }
  // apiCORS adds permissive CORS headers to every response and answers preflight requests.
  //
  // It returns true when the request was an OPTIONS request, in which case status 200 has
  // already been written and the caller should stop handling the request. It returns false
  // for every other method.
  func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool {
  	// CORS is always enabled: a browser may send the Origin header for m3u8 but not for ts, so
  	// the headers are written unconditionally. SRS needs neither cookies nor credentials, which
  	// is why the wildcard is used for origin, headers and methods.
  	// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
  	// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
  	for _, name := range []string{
  		"Access-Control-Allow-Origin",
  		"Access-Control-Allow-Headers",
  		"Access-Control-Allow-Methods",
  	} {
  		w.Header().Set(name, "*")
  	}

  	preflight := r.Method == http.MethodOptions
  	if preflight {
  		w.WriteHeader(http.StatusOK)
  	}
  	return preflight
  }
  62. func parseGracefullyQuitTimeout() (time.Duration, error) {
  63. if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil {
  64. return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout())
  65. } else {
  66. return t, nil
  67. }
  68. }
  69. // ParseBody read the body from r, and unmarshal JSON to v.
  70. func ParseBody(r io.ReadCloser, v interface{}) error {
  71. b, err := ioutil.ReadAll(r)
  72. if err != nil {
  73. return errors.Wrapf(err, "read body")
  74. }
  75. defer r.Close()
  76. if len(b) == 0 {
  77. return nil
  78. }
  79. if err := json.Unmarshal(b, v); err != nil {
  80. return errors.Wrapf(err, "json unmarshal %v", string(b))
  81. }
  82. return nil
  83. }
  84. // buildStreamURL build as vhost/app/stream for stream URL r.
  85. func buildStreamURL(r string) (string, error) {
  86. u, err := url.Parse(r)
  87. if err != nil {
  88. return "", errors.Wrapf(err, "parse url %v", r)
  89. }
  90. // If not domain or ip in hostname, it's __defaultVhost__.
  91. defaultVhost := !strings.Contains(u.Hostname(), ".")
  92. // If hostname is actually an IP address, it's __defaultVhost__.
  93. if ip := net.ParseIP(u.Hostname()); ip.To4() != nil {
  94. defaultVhost = true
  95. }
  96. if defaultVhost {
  97. return fmt.Sprintf("__defaultVhost__%v", u.Path), nil
  98. }
  99. // Ignore port, only use hostname as vhost.
  100. return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil
  101. }
  102. // isPeerClosedError indicates whether peer object closed the connection.
  103. func isPeerClosedError(err error) bool {
  104. causeErr := errors.Cause(err)
  105. if stdErr.Is(causeErr, io.EOF) {
  106. return true
  107. }
  108. if stdErr.Is(causeErr, syscall.EPIPE) {
  109. return true
  110. }
  111. if netErr, ok := causeErr.(*net.OpError); ok {
  112. if sysErr, ok := netErr.Err.(*os.SyscallError); ok {
  113. if stdErr.Is(sysErr.Err, syscall.ECONNRESET) {
  114. return true
  115. }
  116. }
  117. }
  118. return false
  119. }
  // convertURLToStreamURL derives two stream URLs from the HTTP request r. The unifiedURL has
  // the format scheme://vhost/app/stream without any extension, while the fullURL is the
  // unifiedURL followed by the extension of the request path, if any.
  //
  // The scheme is https when r.TLS is set and http otherwise. The vhost is the host part of
  // r.Host when r.Host contains a colon and can be split into host and port; in every other
  // case it is __defaultVhost__. The app and stream query parameters take precedence; only when
  // both are empty is the request path used, with its extension stripped.
  //
  // NOTE(review): a Host header without a port (for example "example.com") maps to
  // __defaultVhost__ rather than to the host name - confirm this is intended.
  func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) {
  	scheme := "https"
  	if r.TLS == nil {
  		scheme = "http"
  	}

  	vhost := "__defaultVhost__"
  	if strings.Contains(r.Host, ":") {
  		if host, _, err := net.SplitHostPort(r.Host); err == nil {
  			vhost = host
  		}
  	}

  	// Prefer app/stream from the query string.
  	query := r.URL.Query()
  	app, stream := query.Get("app"), query.Get("stream")

  	var appStream, ext string
  	if app != "" {
  		appStream = "/" + app
  	}
  	if stream != "" {
  		appStream = appStream + "/" + stream
  	}

  	// Fall back to the request path, remembering its extension for the full URL.
  	if appStream == "" {
  		ext = path.Ext(r.URL.Path)
  		appStream = strings.TrimSuffix(r.URL.Path, ext)
  	}

  	unifiedURL = fmt.Sprintf("%v://%v%v", scheme, vhost, appStream)
  	fullURL = fmt.Sprintf("%v%v", unifiedURL, ext)
  	return unifiedURL, fullURL
  }
  // rtcIsSTUN reports whether the UDP payload data is taken to be a STUN packet, which is the
  // case when it is non-empty and its first byte is 0 or 1.
  func rtcIsSTUN(data []byte) bool {
  	if len(data) == 0 {
  		return false
  	}
  	return data[0] <= 1
  }
  // rtcIsRTPOrRTCP reports whether the UDP payload data is taken to be an RTP or RTCP packet,
  // which is the case when it is at least 12 bytes long and the two most significant bits of
  // its first byte are 1 and 0.
  func rtcIsRTPOrRTCP(data []byte) bool {
  	if len(data) < 12 {
  		return false
  	}
  	return data[0]>>6 == 2
  }
  // srtIsHandshake reports whether the UDP payload data is taken to be an SRT handshake packet,
  // which is the case when it is at least 4 bytes long and its first 4 bytes, read as a
  // big-endian integer, equal 0x80000000.
  func srtIsHandshake(data []byte) bool {
  	const handshakeHeader uint32 = 0x80000000

  	if len(data) < 4 {
  		return false
  	}
  	return binary.BigEndian.Uint32(data[:4]) == handshakeHeader
  }
  // srtParseSocketID parses the socket id from the SRT packet data, reading bytes 12 to 15 as
  // a big-endian integer. It returns 0 when data is shorter than 16 bytes.
  func srtParseSocketID(data []byte) uint32 {
  	if len(data) < 16 {
  		return 0
  	}
  	return binary.BigEndian.Uint32(data[12:16])
  }
  171. // parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP.
  172. func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) {
  173. if true {
  174. ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`)
  175. ufragMatch := ufragRe.FindStringSubmatch(sdp)
  176. if len(ufragMatch) <= 1 {
  177. return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp)
  178. }
  179. ufrag = ufragMatch[1]
  180. }
  181. if true {
  182. pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`)
  183. pwdMatch := pwdRe.FindStringSubmatch(sdp)
  184. if len(pwdMatch) <= 1 {
  185. return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp)
  186. }
  187. pwd = pwdMatch[1]
  188. }
  189. return ufrag, pwd, nil
  190. }
  191. // parseSRTStreamID parse the SRT stream id to host(optional) and resource(required).
  192. // See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url
  193. func parseSRTStreamID(sid string) (host, resource string, err error) {
  194. if true {
  195. hostRe := regexp.MustCompile(`h=([^,]+)`)
  196. hostMatch := hostRe.FindStringSubmatch(sid)
  197. if len(hostMatch) > 1 {
  198. host = hostMatch[1]
  199. }
  200. }
  201. if true {
  202. resourceRe := regexp.MustCompile(`r=([^,]+)`)
  203. resourceMatch := resourceRe.FindStringSubmatch(sid)
  204. if len(resourceMatch) <= 1 {
  205. return "", "", errors.Errorf("no resource in sid %v", sid)
  206. }
  207. resource = resourceMatch[1]
  208. }
  209. return host, resource, nil
  210. }
  211. // parseListenEndpoint parse the listen endpoint as:
  212. // port The tcp listen port, like 1935.
  213. // protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935
  214. func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) {
  215. // If no colon in ep, it's port in string.
  216. if !strings.Contains(ep, ":") {
  217. if p, err := strconv.Atoi(ep); err != nil {
  218. return "", nil, 0, errors.Wrapf(err, "parse port %v", ep)
  219. } else {
  220. return "tcp", nil, uint16(p), nil
  221. }
  222. }
  223. // Must be protocol://ip:port schema.
  224. parts := strings.Split(ep, ":")
  225. if len(parts) != 3 {
  226. return "", nil, 0, errors.Errorf("invalid endpoint %v", ep)
  227. }
  228. if p, err := strconv.Atoi(parts[2]); err != nil {
  229. return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2])
  230. } else {
  231. return parts[0], net.ParseIP(parts[1]), uint16(p), nil
  232. }
  233. }