123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553 |
- // Copyright (c) 2024 Winlin
- //
- // SPDX-License-Identifier: MIT
- package main
- import (
- "context"
- "encoding/json"
- "fmt"
- "math/rand"
- "os"
- "strconv"
- "strings"
- "time"
- // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
- "github.com/go-redis/redis/v8"
- "srs-proxy/errors"
- "srs-proxy/logger"
- "srs-proxy/sync"
- )
// Liveness windows. An entry refreshed within its window is treated as alive.
const (
	// srsServerAliveDuration is the window for backend server heartbeats.
	srsServerAliveDuration = 300 * time.Second
	// srsHLSAliveDuration is the window for HLS streaming updates.
	srsHLSAliveDuration = 120 * time.Second
	// srsRTCAliveDuration is the window for WebRTC streaming updates.
	srsRTCAliveDuration = 120 * time.Second
)
// SRSServer describes one SRS media server backend known to the proxy.
type SRSServer struct {
	// IP is the address of the server.
	IP string `json:"ip,omitempty"`
	// DeviceID is the device identifier configured by the user.
	DeviceID string `json:"device_id,omitempty"`
	// ServerID is the SRS server id, stored in a file so it may survive restarts. Mandatory.
	ServerID string `json:"server_id,omitempty"`
	// ServiceID is the SRS service id, which changes on every restart. Mandatory.
	ServiceID string `json:"service_id,omitempty"`
	// PID is the SRS process id, which changes on every restart. Mandatory.
	PID string `json:"pid,omitempty"`
	// RTMP lists the RTMP listen endpoints.
	RTMP []string `json:"rtmp,omitempty"`
	// HTTP lists the HTTP stream listen endpoints.
	HTTP []string `json:"http,omitempty"`
	// API lists the HTTP API listen endpoints.
	API []string `json:"api,omitempty"`
	// SRT lists the SRT server listen endpoints.
	SRT []string `json:"srt,omitempty"`
	// RTC lists the RTC server listen endpoints.
	RTC []string `json:"rtc,omitempty"`
	// UpdatedAt is the time of the last update.
	UpdatedAt time.Time `json:"update_at,omitempty"`
}

// ID returns the server id, service id and pid joined by dashes.
func (v *SRSServer) ID() string {
	return v.ServerID + "-" + v.ServiceID + "-" + v.PID
}

// String returns the short form produced by Format for the plain %v verb.
func (v *SRSServer) String() string {
	return fmt.Sprint(v)
}

// Format implements fmt.Formatter. The verbs %v and %s print the short form
// "SRS ip=..., id=..."; with the '+' flag the identifiers, any non-empty
// endpoint lists and the update time are appended. Any other verb prints the
// short form followed by ", fmt=%<verb>".
func (v *SRSServer) Format(f fmt.State, c rune) {
	// Unsupported verb: short form plus an echo of the verb.
	if c != 'v' && c != 's' {
		fmt.Fprintf(f, "%v, fmt=%%%c", v, c)
		return
	}

	// Short form, without the '+' flag.
	if !f.Flag('+') {
		fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID())
		return
	}

	// Detailed form: optional parts are emitted only when they are set.
	var detail strings.Builder
	fmt.Fprintf(&detail, "pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID)
	if v.DeviceID != "" {
		fmt.Fprintf(&detail, ", device=%v", v.DeviceID)
	}

	endpoints := []struct {
		label string
		list  []string
	}{
		{"rtmp", v.RTMP},
		{"http", v.HTTP},
		{"api", v.API},
		{"srt", v.SRT},
		{"rtc", v.RTC},
	}
	for _, ep := range endpoints {
		if len(ep.list) > 0 {
			fmt.Fprintf(&detail, ", %v=[%v]", ep.label, strings.Join(ep.list, ","))
		}
	}

	fmt.Fprintf(&detail, ", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999"))
	fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), detail.String())
}

// NewSRSServer creates an empty SRSServer and applies the options in order.
func NewSRSServer(opts ...func(*SRSServer)) *SRSServer {
	server := new(SRSServer)
	for i := range opts {
		opts[i](server)
	}
	return server
}
- // NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only.
- func NewDefaultSRSForDebugging() (*SRSServer, error) {
- if envDefaultBackendEnabled() != "on" {
- return nil, nil
- }
- if envDefaultBackendIP() == "" {
- return nil, fmt.Errorf("empty default backend ip")
- }
- if envDefaultBackendRTMP() == "" {
- return nil, fmt.Errorf("empty default backend rtmp")
- }
- server := NewSRSServer(func(srs *SRSServer) {
- srs.IP = envDefaultBackendIP()
- srs.RTMP = []string{envDefaultBackendRTMP()}
- srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID())
- srs.ServiceID = logger.GenerateContextID()
- srs.PID = fmt.Sprintf("%v", os.Getpid())
- srs.UpdatedAt = time.Now()
- })
- if envDefaultBackendHttp() != "" {
- server.HTTP = []string{envDefaultBackendHttp()}
- }
- if envDefaultBackendAPI() != "" {
- server.API = []string{envDefaultBackendAPI()}
- }
- if envDefaultBackendRTC() != "" {
- server.RTC = []string{envDefaultBackendRTC()}
- }
- if envDefaultBackendSRT() != "" {
- server.SRT = []string{envDefaultBackendSRT()}
- }
- return server, nil
- }
- // SRSLoadBalancer is the interface to load balance the SRS servers.
- type SRSLoadBalancer interface {
- // Initialize the load balancer.
- Initialize(ctx context.Context) error
- // Update the backer server.
- Update(ctx context.Context, server *SRSServer) error
- // Pick a backend server for the specified stream URL.
- Pick(ctx context.Context, streamURL string) (*SRSServer, error)
- // Load or store the HLS streaming for the specified stream URL.
- LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error)
- // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID.
- LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error)
- // Store the WebRTC streaming for the specified stream URL.
- StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error
- // Load the WebRTC streaming by ufrag, the ICE username.
- LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error)
- }
- // srsLoadBalancer is the global SRS load balancer.
- var srsLoadBalancer SRSLoadBalancer
- // srsMemoryLoadBalancer stores state in memory.
- type srsMemoryLoadBalancer struct {
- // All available SRS servers, key is server ID.
- servers sync.Map[string, *SRSServer]
- // The picked server to servce client by specified stream URL, key is stream url.
- picked sync.Map[string, *SRSServer]
- // The HLS streaming, key is stream URL.
- hlsStreamURL sync.Map[string, *HLSPlayStream]
- // The HLS streaming, key is SPBHID.
- hlsSPBHID sync.Map[string, *HLSPlayStream]
- // The WebRTC streaming, key is stream URL.
- rtcStreamURL sync.Map[string, *RTCConnection]
- // The WebRTC streaming, key is ufrag.
- rtcUfrag sync.Map[string, *RTCConnection]
- }
- func NewMemoryLoadBalancer() SRSLoadBalancer {
- return &srsMemoryLoadBalancer{}
- }
- func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error {
- if server, err := NewDefaultSRSForDebugging(); err != nil {
- return errors.Wrapf(err, "initialize default SRS")
- } else if server != nil {
- if err := v.Update(ctx, server); err != nil {
- return errors.Wrapf(err, "update default SRS %+v", server)
- }
- // Keep alive.
- go func() {
- for {
- select {
- case <-ctx.Done():
- return
- case <-time.After(30 * time.Second):
- if err := v.Update(ctx, server); err != nil {
- logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
- }
- }
- }
- }()
- logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server)
- }
- return nil
- }
- func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
- v.servers.Store(server.ID(), server)
- return nil
- }
- func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
- // Always proxy to the same server for the same stream URL.
- if server, ok := v.picked.Load(streamURL); ok {
- return server, nil
- }
- // Gather all servers that were alive within the last few seconds.
- var servers []*SRSServer
- v.servers.Range(func(key string, server *SRSServer) bool {
- if time.Since(server.UpdatedAt) < srsServerAliveDuration {
- servers = append(servers, server)
- }
- return true
- })
- // If no servers available, use all possible servers.
- if len(servers) == 0 {
- v.servers.Range(func(key string, server *SRSServer) bool {
- servers = append(servers, server)
- return true
- })
- }
- // No server found, failed.
- if len(servers) == 0 {
- return nil, fmt.Errorf("no server available for %v", streamURL)
- }
- // Pick a server randomly from servers.
- server := servers[rand.Intn(len(servers))]
- v.picked.Store(streamURL, server)
- return server, nil
- }
- func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
- // Load the HLS streaming for the SPBHID, for TS files.
- if actual, ok := v.hlsSPBHID.Load(spbhid); !ok {
- return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid)
- } else {
- return actual, nil
- }
- }
- func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
- // Update the HLS streaming for the stream URL, for M3u8.
- actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value)
- if actual == nil {
- return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL)
- }
- // Update the HLS streaming for the SPBHID, for TS files.
- v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual)
- return actual, nil
- }
- func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
- // Update the WebRTC streaming for the stream URL.
- v.rtcStreamURL.Store(streamURL, value)
- // Update the WebRTC streaming for the ufrag.
- v.rtcUfrag.Store(value.Ufrag, value)
- return nil
- }
- func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
- if actual, ok := v.rtcUfrag.Load(ufrag); !ok {
- return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag)
- } else {
- return actual, nil
- }
- }
- type srsRedisLoadBalancer struct {
- // The redis client sdk.
- rdb *redis.Client
- }
- func NewRedisLoadBalancer() SRSLoadBalancer {
- return &srsRedisLoadBalancer{}
- }
- func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error {
- redisDatabase, err := strconv.Atoi(envRedisDB())
- if err != nil {
- return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB())
- }
- rdb := redis.NewClient(&redis.Options{
- Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()),
- Password: envRedisPassword(),
- DB: redisDatabase,
- })
- v.rdb = rdb
- if err := rdb.Ping(ctx).Err(); err != nil {
- return errors.Wrapf(err, "unable to connect to redis %v", rdb.String())
- }
- logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String())
- if server, err := NewDefaultSRSForDebugging(); err != nil {
- return errors.Wrapf(err, "initialize default SRS")
- } else if server != nil {
- if err := v.Update(ctx, server); err != nil {
- return errors.Wrapf(err, "update default SRS %+v", server)
- }
- // Keep alive.
- go func() {
- for {
- select {
- case <-ctx.Done():
- return
- case <-time.After(30 * time.Second):
- if err := v.Update(ctx, server); err != nil {
- logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err)
- }
- }
- }
- }()
- logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server)
- }
- return nil
- }
- func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error {
- b, err := json.Marshal(server)
- if err != nil {
- return errors.Wrapf(err, "marshal server %+v", server)
- }
- key := v.redisKeyServer(server.ID())
- if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil {
- return errors.Wrapf(err, "set key=%v server %+v", key, server)
- }
- // Query all servers from redis, in json string.
- var serverKeys []string
- if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
- if err := json.Unmarshal(b, &serverKeys); err != nil {
- return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
- }
- }
- // Check each server expiration, if not exists in redis, remove from servers.
- for i := len(serverKeys) - 1; i >= 0; i-- {
- if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil {
- serverKeys = append(serverKeys[:i], serverKeys[i+1:]...)
- }
- }
- // Add server to servers if not exists.
- var found bool
- for _, serverKey := range serverKeys {
- if serverKey == key {
- found = true
- break
- }
- }
- if !found {
- serverKeys = append(serverKeys, key)
- }
- // Update all servers to redis.
- b, err = json.Marshal(serverKeys)
- if err != nil {
- return errors.Wrapf(err, "marshal servers %+v", serverKeys)
- }
- if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil {
- return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys)
- }
- return nil
- }
- func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) {
- key := fmt.Sprintf("srs-proxy-url:%v", streamURL)
- // Always proxy to the same server for the same stream URL.
- if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil {
- // If server not exists, ignore and pick another server for the stream URL.
- if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 {
- var server SRSServer
- if err := json.Unmarshal(b, &server); err != nil {
- return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b))
- }
- // TODO: If server fail, we should migrate the streams to another server.
- return &server, nil
- }
- }
- // Query all servers from redis, in json string.
- var serverKeys []string
- if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil {
- if err := json.Unmarshal(b, &serverKeys); err != nil {
- return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b))
- }
- }
- // No server found, failed.
- if len(serverKeys) == 0 {
- return nil, fmt.Errorf("no server available for %v", streamURL)
- }
- // All server should be alive, if not, should have been removed by redis. So we only
- // random pick one that is always available.
- var serverKey string
- var server SRSServer
- for i := 0; i < 3; i++ {
- tryServerKey := serverKeys[rand.Intn(len(serverKeys))]
- b, err := v.rdb.Get(ctx, tryServerKey).Bytes()
- if err == nil && len(b) > 0 {
- if err := json.Unmarshal(b, &server); err != nil {
- return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b))
- }
- serverKey = tryServerKey
- break
- }
- }
- if serverKey == "" {
- return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL)
- }
- // Update the picked server for the stream URL.
- if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil {
- return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey)
- }
- return &server, nil
- }
- func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) {
- key := v.redisKeySPBHID(spbhid)
- b, err := v.rdb.Get(ctx, key).Bytes()
- if err != nil {
- return nil, errors.Wrapf(err, "get key=%v HLS", key)
- }
- var actual HLSPlayStream
- if err := json.Unmarshal(b, &actual); err != nil {
- return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b))
- }
- return &actual, nil
- }
- func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) {
- b, err := json.Marshal(value)
- if err != nil {
- return nil, errors.Wrapf(err, "marshal HLS %v", value)
- }
- key := v.redisKeyHLS(streamURL)
- if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil {
- return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value)
- }
- key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID)
- if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil {
- return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value)
- }
- // Query the HLS streaming from redis.
- b2, err := v.rdb.Get(ctx, key).Bytes()
- if err != nil {
- return nil, errors.Wrapf(err, "get key=%v HLS", key)
- }
- var actual HLSPlayStream
- if err := json.Unmarshal(b2, &actual); err != nil {
- return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2))
- }
- return &actual, nil
- }
- func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error {
- b, err := json.Marshal(value)
- if err != nil {
- return errors.Wrapf(err, "marshal WebRTC %v", value)
- }
- key := v.redisKeyRTC(streamURL)
- if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil {
- return errors.Wrapf(err, "set key=%v WebRTC %v", key, value)
- }
- key2 := v.redisKeyUfrag(value.Ufrag)
- if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil {
- return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value)
- }
- return nil
- }
- func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) {
- key := v.redisKeyUfrag(ufrag)
- b, err := v.rdb.Get(ctx, key).Bytes()
- if err != nil {
- return nil, errors.Wrapf(err, "get key=%v WebRTC", key)
- }
- var actual RTCConnection
- if err := json.Unmarshal(b, &actual); err != nil {
- return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b))
- }
- return &actual, nil
- }
- func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string {
- return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag)
- }
- func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string {
- return fmt.Sprintf("srs-proxy-rtc:%v", streamURL)
- }
- func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string {
- return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid)
- }
- func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string {
- return fmt.Sprintf("srs-proxy-hls:%v", streamURL)
- }
- func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string {
- return fmt.Sprintf("srs-proxy-server:%v", serverID)
- }
- func (v *srsRedisLoadBalancer) redisKeyServers() string {
- return fmt.Sprintf("srs-proxy-all-servers")
- }
|