server.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. // Package turn contains the public API for pion/turn, a toolkit for building TURN clients and servers
  2. package turn
  3. import (
  4. "fmt"
  5. "net"
  6. "sync"
  7. "time"
  8. "github.com/pion/logging"
  9. "github.com/pion/turn/v2/internal/allocation"
  10. "github.com/pion/turn/v2/internal/proto"
  11. "github.com/pion/turn/v2/internal/server"
  12. )
  13. const (
  14. defaultInboundMTU = 1600
  15. )
  16. // Server is an instance of the Pion TURN Server
  17. type Server struct {
  18. log logging.LeveledLogger
  19. authHandler AuthHandler
  20. realm string
  21. channelBindTimeout time.Duration
  22. nonces *sync.Map
  23. packetConnConfigs []PacketConnConfig
  24. listenerConfigs []ListenerConfig
  25. allocationManagers []*allocation.Manager
  26. inboundMTU int
  27. }
  28. // NewServer creates the Pion TURN server
  29. //
  30. //nolint:gocognit
  31. func NewServer(config ServerConfig) (*Server, error) {
  32. if err := config.validate(); err != nil {
  33. return nil, err
  34. }
  35. loggerFactory := config.LoggerFactory
  36. if loggerFactory == nil {
  37. loggerFactory = logging.NewDefaultLoggerFactory()
  38. }
  39. mtu := defaultInboundMTU
  40. if config.InboundMTU != 0 {
  41. mtu = config.InboundMTU
  42. }
  43. s := &Server{
  44. log: loggerFactory.NewLogger("turn"),
  45. authHandler: config.AuthHandler,
  46. realm: config.Realm,
  47. channelBindTimeout: config.ChannelBindTimeout,
  48. packetConnConfigs: config.PacketConnConfigs,
  49. listenerConfigs: config.ListenerConfigs,
  50. nonces: &sync.Map{},
  51. inboundMTU: mtu,
  52. }
  53. if s.channelBindTimeout == 0 {
  54. s.channelBindTimeout = proto.DefaultLifetime
  55. }
  56. for _, cfg := range s.packetConnConfigs {
  57. am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler)
  58. if err != nil {
  59. return nil, fmt.Errorf("failed to create AllocationManager: %w", err)
  60. }
  61. go s.readPacketConn(cfg, am)
  62. }
  63. for _, cfg := range s.listenerConfigs {
  64. am, err := s.createAllocationManager(cfg.RelayAddressGenerator, cfg.PermissionHandler)
  65. if err != nil {
  66. return nil, fmt.Errorf("failed to create AllocationManager: %w", err)
  67. }
  68. go s.readListener(cfg, am)
  69. }
  70. return s, nil
  71. }
  72. // AllocationCount returns the number of active allocations. It can be used to drain the server before closing
  73. func (s *Server) AllocationCount() int {
  74. allocs := 0
  75. for _, am := range s.allocationManagers {
  76. allocs += am.AllocationCount()
  77. }
  78. return allocs
  79. }
  80. // Close stops the TURN Server. It cleans up any associated state and closes all connections it is managing
  81. func (s *Server) Close() error {
  82. var errors []error
  83. for _, cfg := range s.packetConnConfigs {
  84. if err := cfg.PacketConn.Close(); err != nil {
  85. errors = append(errors, err)
  86. }
  87. }
  88. for _, cfg := range s.listenerConfigs {
  89. if err := cfg.Listener.Close(); err != nil {
  90. errors = append(errors, err)
  91. }
  92. }
  93. if len(errors) == 0 {
  94. return nil
  95. }
  96. err := errFailedToClose
  97. for _, e := range errors {
  98. err = fmt.Errorf("%s; close error (%w) ", err, e)
  99. }
  100. return err
  101. }
  102. func (s *Server) readPacketConn(p PacketConnConfig, am *allocation.Manager) {
  103. s.readLoop(p.PacketConn, am)
  104. if err := am.Close(); err != nil {
  105. s.log.Errorf("Failed to close AllocationManager: %s", err)
  106. }
  107. }
  108. func (s *Server) readListener(l ListenerConfig, am *allocation.Manager) {
  109. defer func() {
  110. if err := am.Close(); err != nil {
  111. s.log.Errorf("Failed to close AllocationManager: %s", err)
  112. }
  113. }()
  114. for {
  115. conn, err := l.Listener.Accept()
  116. if err != nil {
  117. s.log.Debugf("Failed to accept: %s", err)
  118. return
  119. }
  120. go s.readLoop(NewSTUNConn(conn), am)
  121. }
  122. }
  123. func (s *Server) createAllocationManager(addrGenerator RelayAddressGenerator, handler PermissionHandler) (*allocation.Manager, error) {
  124. if handler == nil {
  125. handler = DefaultPermissionHandler
  126. }
  127. am, err := allocation.NewManager(allocation.ManagerConfig{
  128. AllocatePacketConn: addrGenerator.AllocatePacketConn,
  129. AllocateConn: addrGenerator.AllocateConn,
  130. PermissionHandler: handler,
  131. LeveledLogger: s.log,
  132. })
  133. if err != nil {
  134. return am, err
  135. }
  136. s.allocationManagers = append(s.allocationManagers, am)
  137. return am, err
  138. }
  139. func (s *Server) readLoop(p net.PacketConn, allocationManager *allocation.Manager) {
  140. buf := make([]byte, s.inboundMTU)
  141. for {
  142. n, addr, err := p.ReadFrom(buf)
  143. switch {
  144. case err != nil:
  145. s.log.Debugf("exit read loop on error: %s", err.Error())
  146. return
  147. case n >= s.inboundMTU:
  148. s.log.Debugf("Read bytes exceeded MTU, packet is possibly truncated")
  149. }
  150. if err := server.HandleRequest(server.Request{
  151. Conn: p,
  152. SrcAddr: addr,
  153. Buff: buf[:n],
  154. Log: s.log,
  155. AuthHandler: s.authHandler,
  156. Realm: s.realm,
  157. AllocationManager: allocationManager,
  158. ChannelBindTimeout: s.channelBindTimeout,
  159. Nonces: s.nonces,
  160. }); err != nil {
  161. s.log.Errorf("error when handling datagram: %v", err)
  162. }
  163. }
  164. }