123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655 |
- // Copyright (c) 2024 Winlin
- //
- // SPDX-License-Identifier: MIT
- package main
- import (
- "context"
- "fmt"
- "math/rand"
- "net"
- "strconv"
- "strings"
- "sync"
- "time"
- "srs-proxy/errors"
- "srs-proxy/logger"
- "srs-proxy/rtmp"
- )
// srsRTMPServer is the RTMP front end of the proxy. It accepts RTMP clients and relays
// each stream to a backend SRS server that it figures out for that stream. Unlike an
// edge server, it does not cache the stream; it only forwards it to the backend.
type srsRTMPServer struct {
	// listener is the TCP listener that accepts RTMP clients.
	listener *net.TCPListener

	// rd is the random number generator of the server.
	rd *rand.Rand

	// wg tracks all goroutines started by the server.
	wg sync.WaitGroup
}
- func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer {
- v := &srsRTMPServer{
- rd: rand.New(rand.NewSource(time.Now().UnixNano())),
- }
- for _, opt := range opts {
- opt(v)
- }
- return v
- }
- func (v *srsRTMPServer) Close() error {
- if v.listener != nil {
- v.listener.Close()
- }
- v.wg.Wait()
- return nil
- }
- func (v *srsRTMPServer) Run(ctx context.Context) error {
- endpoint := envRtmpServer()
- if !strings.Contains(endpoint, ":") {
- endpoint = ":" + endpoint
- }
- addr, err := net.ResolveTCPAddr("tcp", endpoint)
- if err != nil {
- return errors.Wrapf(err, "resolve rtmp addr %v", endpoint)
- }
- listener, err := net.ListenTCP("tcp", addr)
- if err != nil {
- return errors.Wrapf(err, "listen rtmp addr %v", addr)
- }
- v.listener = listener
- logger.Df(ctx, "RTMP server listen at %v", addr)
- v.wg.Add(1)
- go func() {
- defer v.wg.Done()
- for {
- conn, err := v.listener.AcceptTCP()
- if err != nil {
- if ctx.Err() != context.Canceled {
- // TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit.
- logger.Wf(ctx, "RTMP server accept err %+v", err)
- } else {
- logger.Df(ctx, "RTMP server done")
- }
- return
- }
- v.wg.Add(1)
- go func(ctx context.Context, conn *net.TCPConn) {
- defer v.wg.Done()
- defer conn.Close()
- handleErr := func(err error) {
- if isPeerClosedError(err) {
- logger.Df(ctx, "RTMP peer is closed")
- } else {
- logger.Wf(ctx, "RTMP serve err %+v", err)
- }
- }
- rc := NewRTMPConnection(func(client *RTMPConnection) {
- client.rd = v.rd
- })
- if err := rc.serve(ctx, conn); err != nil {
- handleErr(err)
- } else {
- logger.Df(ctx, "RTMP client done")
- }
- }(logger.WithContext(ctx), conn)
- }
- }()
- return nil
- }
// RTMPConnection is one RTMP streaming connection from a client. No state needs to be
// synchronized between proxy servers for it.
//
// The stream URL is parsed from the RTMP publish or play request, and the connection is
// then proxied to the corresponding backend server. Everything needed is carried by the
// RTMP request itself, so the connection is stateless.
type RTMPConnection struct {
	// rd is the random number generator used by this connection.
	rd *rand.Rand
}
- func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection {
- v := &RTMPConnection{}
- for _, opt := range opts {
- opt(v)
- }
- return v
- }
- func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error {
- logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr())
- // If any goroutine quit, cancel another one.
- parentCtx := ctx
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()
- var backend *RTMPClientToBackend
- if true {
- go func() {
- <-ctx.Done()
- conn.Close()
- if backend != nil {
- backend.Close()
- }
- }()
- }
- // Simple handshake with client.
- hs := rtmp.NewHandshake(v.rd)
- if _, err := hs.ReadC0S0(conn); err != nil {
- return errors.Wrapf(err, "read c0")
- }
- if _, err := hs.ReadC1S1(conn); err != nil {
- return errors.Wrapf(err, "read c1")
- }
- if err := hs.WriteC0S0(conn); err != nil {
- return errors.Wrapf(err, "write s1")
- }
- if err := hs.WriteC1S1(conn); err != nil {
- return errors.Wrapf(err, "write s1")
- }
- if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil {
- return errors.Wrapf(err, "write s2")
- }
- if _, err := hs.ReadC2S2(conn); err != nil {
- return errors.Wrapf(err, "read c2")
- }
- client := rtmp.NewProtocol(conn)
- logger.Df(ctx, "RTMP simple handshake done")
- // Expect RTMP connect command with tcUrl.
- var connectReq *rtmp.ConnectAppPacket
- if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil {
- return errors.Wrapf(err, "expect connect req")
- }
- if true {
- ack := rtmp.NewWindowAcknowledgementSize()
- ack.AckSize = 2500000
- if err := client.WritePacket(ctx, ack, 0); err != nil {
- return errors.Wrapf(err, "write set ack size")
- }
- }
- if true {
- chunk := rtmp.NewSetChunkSize()
- chunk.ChunkSize = 128
- if err := client.WritePacket(ctx, chunk, 0); err != nil {
- return errors.Wrapf(err, "write set chunk size")
- }
- }
- connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID)
- connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888"))
- connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127))
- connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1))
- connectRes.Args.Set("level", rtmp.NewAmf0String("status"))
- connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success"))
- connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded"))
- connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0))
- connectResData := rtmp.NewAmf0EcmaArray()
- connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888"))
- connectResData.Set("srs_version", rtmp.NewAmf0String(Version()))
- connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx)))
- connectRes.Args.Set("data", connectResData)
- if err := client.WritePacket(ctx, connectRes, 0); err != nil {
- return errors.Wrapf(err, "write connect res")
- }
- tcUrl := connectReq.TcUrl()
- logger.Df(ctx, "RTMP connect app %v", tcUrl)
- // Expect RTMP command to identify the client, a publisher or viewer.
- var currentStreamID, nextStreamID int
- var streamName string
- var clientType RTMPClientType
- for clientType == "" {
- var identifyReq rtmp.Packet
- if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil {
- return errors.Wrapf(err, "expect identify req")
- }
- var response rtmp.Packet
- switch pkt := identifyReq.(type) {
- case *rtmp.CallPacket:
- if pkt.CommandName == "createStream" {
- identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID)
- response = identifyRes
- nextStreamID = 1
- identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID))
- } else if pkt.CommandName == "getStreamLength" {
- // Ignore and do not reply these packets.
- } else {
- // For releaseStream, FCPublish, etc.
- identifyRes := rtmp.NewCallPacket()
- response = identifyRes
- identifyRes.TransactionID = pkt.TransactionID
- identifyRes.CommandName = "_result"
- identifyRes.CommandObject = rtmp.NewAmf0Null()
- identifyRes.Args = rtmp.NewAmf0Undefined()
- }
- case *rtmp.PublishPacket:
- streamName = string(pkt.StreamName)
- clientType = RTMPClientTypePublisher
- identifyRes := rtmp.NewCallPacket()
- response = identifyRes
- identifyRes.CommandName = "onFCPublish"
- identifyRes.CommandObject = rtmp.NewAmf0Null()
- data := rtmp.NewAmf0Object()
- data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
- data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
- identifyRes.Args = data
- case *rtmp.PlayPacket:
- streamName = string(pkt.StreamName)
- clientType = RTMPClientTypeViewer
- identifyRes := rtmp.NewCallPacket()
- response = identifyRes
- identifyRes.CommandName = "onStatus"
- identifyRes.CommandObject = rtmp.NewAmf0Null()
- data := rtmp.NewAmf0Object()
- data.Set("level", rtmp.NewAmf0String("status"))
- data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset"))
- data.Set("description", rtmp.NewAmf0String("Playing and resetting stream."))
- data.Set("details", rtmp.NewAmf0String("stream"))
- data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
- identifyRes.Args = data
- }
- if response != nil {
- if err := client.WritePacket(ctx, response, currentStreamID); err != nil {
- return errors.Wrapf(err, "write identify res for req=%v, stream=%v",
- identifyReq, currentStreamID)
- }
- }
- // Update the stream ID for next request.
- currentStreamID = nextStreamID
- }
- logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v",
- tcUrl, streamName, currentStreamID, clientType)
- // Find a backend SRS server to proxy the RTMP stream.
- backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) {
- client.rd, client.typ = v.rd, clientType
- })
- defer backend.Close()
- if err := backend.Connect(ctx, tcUrl, streamName); err != nil {
- return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName)
- }
- // Start the streaming.
- if clientType == RTMPClientTypePublisher {
- identifyRes := rtmp.NewCallPacket()
- identifyRes.CommandName = "onStatus"
- identifyRes.CommandObject = rtmp.NewAmf0Null()
- data := rtmp.NewAmf0Object()
- data.Set("level", rtmp.NewAmf0String("status"))
- data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start"))
- data.Set("description", rtmp.NewAmf0String("Started publishing stream."))
- data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
- identifyRes.Args = data
- if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
- return errors.Wrapf(err, "start publish")
- }
- } else if clientType == RTMPClientTypeViewer {
- identifyRes := rtmp.NewCallPacket()
- identifyRes.CommandName = "onStatus"
- identifyRes.CommandObject = rtmp.NewAmf0Null()
- data := rtmp.NewAmf0Object()
- data.Set("level", rtmp.NewAmf0String("status"))
- data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start"))
- data.Set("description", rtmp.NewAmf0String("Started playing stream."))
- data.Set("details", rtmp.NewAmf0String("stream"))
- data.Set("clientid", rtmp.NewAmf0String("ASAICiss"))
- identifyRes.Args = data
- if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil {
- return errors.Wrapf(err, "start play")
- }
- }
- logger.Df(ctx, "RTMP start streaming")
- // For all proxy goroutines.
- var wg sync.WaitGroup
- defer wg.Wait()
- // Proxy all message from backend to client.
- wg.Add(1)
- var r0 error
- go func() {
- defer wg.Done()
- defer cancel()
- r0 = func() error {
- for {
- m, err := backend.client.ReadMessage(ctx)
- if err != nil {
- return errors.Wrapf(err, "read message")
- }
- //logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
- // TODO: Update the stream ID if not the same.
- if err := client.WriteMessage(ctx, m); err != nil {
- return errors.Wrapf(err, "write message")
- }
- }
- }()
- }()
- // Proxy all messages from client to backend.
- wg.Add(1)
- var r1 error
- go func() {
- defer wg.Done()
- defer cancel()
- r1 = func() error {
- for {
- m, err := client.ReadMessage(ctx)
- if err != nil {
- return errors.Wrapf(err, "read message")
- }
- //logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload))
- // TODO: Update the stream ID if not the same.
- if err := backend.client.WriteMessage(ctx, m); err != nil {
- return errors.Wrapf(err, "write message")
- }
- }
- }()
- }()
- // Wait until all goroutine quit.
- wg.Wait()
- // Reset the error if caused by another goroutine.
- if r0 != nil {
- return errors.Wrapf(r0, "proxy backend->client")
- }
- if r1 != nil {
- return errors.Wrapf(r1, "proxy client->backend")
- }
- return parentCtx.Err()
- }
// RTMPClientType tells whether an RTMP client publishes a stream or plays one.
type RTMPClientType string

const (
	// RTMPClientTypePublisher is the type of a client that publishes a stream.
	RTMPClientTypePublisher = RTMPClientType("publisher")
	// RTMPClientTypeViewer is the type of a client that plays a stream.
	RTMPClientTypeViewer = RTMPClientType("viewer")
)
- // RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend.
- type RTMPClientToBackend struct {
- // The random number generator.
- rd *rand.Rand
- // The underlayer tcp client.
- tcpConn *net.TCPConn
- // The RTMP protocol client.
- client *rtmp.Protocol
- // The stream type.
- typ RTMPClientType
- }
- func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend {
- v := &RTMPClientToBackend{}
- for _, opt := range opts {
- opt(v)
- }
- return v
- }
- func (v *RTMPClientToBackend) Close() error {
- if v.tcpConn != nil {
- v.tcpConn.Close()
- }
- return nil
- }
- func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error {
- // Build the stream URL in vhost/app/stream schema.
- streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName))
- if err != nil {
- return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName)
- }
- // Pick a backend SRS server to proxy the RTMP stream.
- backend, err := srsLoadBalancer.Pick(ctx, streamURL)
- if err != nil {
- return errors.Wrapf(err, "pick backend for %v", streamURL)
- }
- // Parse RTMP port from backend.
- if len(backend.RTMP) == 0 {
- return errors.Errorf("no rtmp server %+v for %v", backend, streamURL)
- }
- var rtmpPort int
- if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil {
- return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0])
- } else {
- rtmpPort = int(iv)
- }
- // Connect to backend SRS server via TCP client.
- addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort}
- c, err := net.DialTCP("tcp", nil, addr)
- if err != nil {
- return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend)
- }
- v.tcpConn = c
- hs := rtmp.NewHandshake(v.rd)
- client := rtmp.NewProtocol(c)
- v.client = client
- // Simple RTMP handshake with server.
- if err := hs.WriteC0S0(c); err != nil {
- return errors.Wrapf(err, "write c0")
- }
- if err := hs.WriteC1S1(c); err != nil {
- return errors.Wrapf(err, "write c1")
- }
- if _, err = hs.ReadC0S0(c); err != nil {
- return errors.Wrapf(err, "read s0")
- }
- if _, err := hs.ReadC1S1(c); err != nil {
- return errors.Wrapf(err, "read s1")
- }
- if _, err = hs.ReadC2S2(c); err != nil {
- return errors.Wrapf(err, "read c2")
- }
- logger.Df(ctx, "backend simple handshake done, server=%v", addr)
- if err := hs.WriteC2S2(c, hs.C1S1()); err != nil {
- return errors.Wrapf(err, "write c2")
- }
- // Connect RTMP app on tcUrl with server.
- if true {
- connectApp := rtmp.NewConnectAppPacket()
- connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl))
- if err := client.WritePacket(ctx, connectApp, 1); err != nil {
- return errors.Wrapf(err, "write connect app")
- }
- }
- if true {
- var connectAppRes *rtmp.ConnectAppResPacket
- if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil {
- return errors.Wrapf(err, "expect connect app res")
- }
- logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID())
- }
- // Play or view RTMP stream with server.
- if v.typ == RTMPClientTypeViewer {
- return v.play(ctx, client, streamName)
- }
- // Publish RTMP stream with server.
- return v.publish(ctx, client, streamName)
- }
- func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error {
- if true {
- identifyReq := rtmp.NewCallPacket()
- identifyReq.CommandName = "releaseStream"
- identifyReq.TransactionID = 2
- identifyReq.CommandObject = rtmp.NewAmf0Null()
- identifyReq.Args = rtmp.NewAmf0String(streamName)
- if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
- return errors.Wrapf(err, "releaseStream")
- }
- }
- for {
- var identifyRes *rtmp.CallPacket
- if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
- return errors.Wrapf(err, "expect releaseStream res")
- }
- if identifyRes.CommandName == "_result" {
- break
- }
- }
- if true {
- identifyReq := rtmp.NewCallPacket()
- identifyReq.CommandName = "FCPublish"
- identifyReq.TransactionID = 3
- identifyReq.CommandObject = rtmp.NewAmf0Null()
- identifyReq.Args = rtmp.NewAmf0String(streamName)
- if err := client.WritePacket(ctx, identifyReq, 0); err != nil {
- return errors.Wrapf(err, "FCPublish")
- }
- }
- for {
- var identifyRes *rtmp.CallPacket
- if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
- return errors.Wrapf(err, "expect FCPublish res")
- }
- if identifyRes.CommandName == "_result" {
- break
- }
- }
- var currentStreamID int
- if true {
- createStream := rtmp.NewCreateStreamPacket()
- createStream.TransactionID = 4
- createStream.CommandObject = rtmp.NewAmf0Null()
- if err := client.WritePacket(ctx, createStream, 0); err != nil {
- return errors.Wrapf(err, "createStream")
- }
- }
- for {
- var identifyRes *rtmp.CreateStreamResPacket
- if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
- return errors.Wrapf(err, "expect createStream res")
- }
- if sid := identifyRes.StreamID; sid != 0 {
- currentStreamID = int(sid)
- break
- }
- }
- if true {
- publishStream := rtmp.NewPublishPacket()
- publishStream.TransactionID = 5
- publishStream.CommandObject = rtmp.NewAmf0Null()
- publishStream.StreamName = *rtmp.NewAmf0String(streamName)
- publishStream.StreamType = *rtmp.NewAmf0String("live")
- if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil {
- return errors.Wrapf(err, "publish")
- }
- }
- for {
- var identifyRes *rtmp.CallPacket
- if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
- return errors.Wrapf(err, "expect publish res")
- }
- // Ignore onFCPublish, expect onStatus(NetStream.Publish.Start).
- if identifyRes.CommandName == "onStatus" {
- if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil {
- return errors.Errorf("onStatus args not object")
- } else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil {
- return errors.Errorf("onStatus code not string")
- } else if *code != "NetStream.Publish.Start" {
- return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code)
- }
- break
- }
- }
- logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID)
- return nil
- }
- func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error {
- var currentStreamID int
- if true {
- createStream := rtmp.NewCreateStreamPacket()
- createStream.TransactionID = 4
- createStream.CommandObject = rtmp.NewAmf0Null()
- if err := client.WritePacket(ctx, createStream, 0); err != nil {
- return errors.Wrapf(err, "createStream")
- }
- }
- for {
- var identifyRes *rtmp.CreateStreamResPacket
- if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
- return errors.Wrapf(err, "expect createStream res")
- }
- if sid := identifyRes.StreamID; sid != 0 {
- currentStreamID = int(sid)
- break
- }
- }
- playStream := rtmp.NewPlayPacket()
- playStream.StreamName = *rtmp.NewAmf0String(streamName)
- if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil {
- return errors.Wrapf(err, "play")
- }
- for {
- var identifyRes *rtmp.CallPacket
- if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil {
- return errors.Wrapf(err, "expect releaseStream res")
- }
- if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" {
- break
- }
- }
- return nil
- }
|