client.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package stun
  4. import (
  5. "crypto/tls"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "log"
  10. "net"
  11. "runtime"
  12. "strconv"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. "github.com/pion/dtls/v2"
  17. "github.com/pion/transport/v2"
  18. "github.com/pion/transport/v2/stdnet"
  19. )
  20. // ErrUnsupportedURI is an error thrown if the user passes an unsupported STUN or TURN URI
  21. var ErrUnsupportedURI = fmt.Errorf("invalid schema or transport")
  22. // Dial connects to the address on the named network and then
  23. // initializes Client on that connection, returning error if any.
  24. func Dial(network, address string) (*Client, error) {
  25. conn, err := net.Dial(network, address)
  26. if err != nil {
  27. return nil, err
  28. }
  29. return NewClient(conn)
  30. }
  31. // DialConfig is used to pass configuration to DialURI()
  32. type DialConfig struct {
  33. DTLSConfig dtls.Config
  34. TLSConfig tls.Config
  35. Net transport.Net
  36. }
  37. // DialURI connect to the STUN/TURN URI and then
  38. // initializes Client on that connection, returning error if any.
  39. func DialURI(uri *URI, cfg *DialConfig) (*Client, error) {
  40. var conn Connection
  41. var err error
  42. nw := cfg.Net
  43. if nw == nil {
  44. nw, err = stdnet.NewNet()
  45. if err != nil {
  46. return nil, fmt.Errorf("failed to create net: %w", err)
  47. }
  48. }
  49. addr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port))
  50. switch {
  51. case uri.Scheme == SchemeTypeSTUN:
  52. if conn, err = nw.Dial("udp", addr); err != nil {
  53. return nil, fmt.Errorf("failed to listen: %w", err)
  54. }
  55. case uri.Scheme == SchemeTypeTURN:
  56. network := "udp" //nolint:goconst
  57. if uri.Proto == ProtoTypeTCP {
  58. network = "tcp" //nolint:goconst
  59. }
  60. if conn, err = nw.Dial(network, addr); err != nil {
  61. return nil, fmt.Errorf("failed to dial: %w", err)
  62. }
  63. case uri.Scheme == SchemeTypeTURNS && uri.Proto == ProtoTypeUDP:
  64. dtlsCfg := cfg.DTLSConfig // Copy
  65. dtlsCfg.ServerName = uri.Host
  66. udpConn, err := nw.Dial("udp", addr)
  67. if err != nil {
  68. return nil, fmt.Errorf("failed to dial: %w", err)
  69. }
  70. if conn, err = dtls.Client(udpConn, &dtlsCfg); err != nil {
  71. return nil, fmt.Errorf("failed to connect to '%s': %w", addr, err)
  72. }
  73. case (uri.Scheme == SchemeTypeTURNS || uri.Scheme == SchemeTypeSTUNS) && uri.Proto == ProtoTypeTCP:
  74. tlsCfg := cfg.TLSConfig //nolint:govet
  75. tlsCfg.ServerName = uri.Host
  76. tcpConn, err := nw.Dial("tcp", addr)
  77. if err != nil {
  78. return nil, fmt.Errorf("failed to dial: %w", err)
  79. }
  80. conn = tls.Client(tcpConn, &tlsCfg)
  81. default:
  82. return nil, ErrUnsupportedURI
  83. }
  84. return NewClient(conn)
  85. }
  86. // ErrNoConnection means that ClientOptions.Connection is nil.
  87. var ErrNoConnection = errors.New("no connection provided")
  88. // ClientOption sets some client option.
  89. type ClientOption func(c *Client)
  90. // WithHandler sets client handler which is called if Agent emits the Event
  91. // with TransactionID that is not currently registered by Client.
  92. // Useful for handling Data indications from TURN server.
  93. func WithHandler(h Handler) ClientOption {
  94. return func(c *Client) {
  95. c.handler = h
  96. }
  97. }
  98. // WithRTO sets client RTO as defined in STUN RFC.
  99. func WithRTO(rto time.Duration) ClientOption {
  100. return func(c *Client) {
  101. c.rto = int64(rto)
  102. }
  103. }
  104. // WithClock sets Clock of client, the source of current time.
  105. // Also clock is passed to default collector if set.
  106. func WithClock(clock Clock) ClientOption {
  107. return func(c *Client) {
  108. c.clock = clock
  109. }
  110. }
  111. // WithTimeoutRate sets RTO timer minimum resolution.
  112. func WithTimeoutRate(d time.Duration) ClientOption {
  113. return func(c *Client) {
  114. c.rtoRate = d
  115. }
  116. }
  117. // WithAgent sets client STUN agent.
  118. //
  119. // Defaults to agent implementation in current package,
  120. // see agent.go.
  121. func WithAgent(a ClientAgent) ClientOption {
  122. return func(c *Client) {
  123. c.a = a
  124. }
  125. }
  126. // WithCollector rests client timeout collector, the implementation
  127. // of ticker which calls function on each tick.
  128. func WithCollector(coll Collector) ClientOption {
  129. return func(c *Client) {
  130. c.collector = coll
  131. }
  132. }
  133. // WithNoConnClose prevents client from closing underlying connection when
  134. // the Close() method is called.
  135. func WithNoConnClose() ClientOption {
  136. return func(c *Client) {
  137. c.closeConn = false
  138. }
  139. }
  140. // WithNoRetransmit disables retransmissions and sets RTO to
  141. // defaultMaxAttempts * defaultRTO which will be effectively time out
  142. // if not set.
  143. //
  144. // Useful for TCP connections where transport handles RTO.
  145. func WithNoRetransmit(c *Client) {
  146. c.maxAttempts = 0
  147. if c.rto == 0 {
  148. c.rto = defaultMaxAttempts * int64(defaultRTO)
  149. }
  150. }
  151. const (
  152. defaultTimeoutRate = time.Millisecond * 5
  153. defaultRTO = time.Millisecond * 300
  154. defaultMaxAttempts = 7
  155. )
  156. // NewClient initializes new Client from provided options,
  157. // starting internal goroutines and using default options fields
  158. // if necessary. Call Close method after using Client to close conn and
  159. // release resources.
  160. //
  161. // The conn will be closed on Close call. Use WithNoConnClose option to
  162. // prevent that.
  163. //
  164. // Note that user should handle the protocol multiplexing, client does not
  165. // provide any API for it, so if you need to read application data, wrap the
  166. // connection with your (de-)multiplexer and pass the wrapper as conn.
  167. func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
  168. c := &Client{
  169. close: make(chan struct{}),
  170. c: conn,
  171. clock: systemClock(),
  172. rto: int64(defaultRTO),
  173. rtoRate: defaultTimeoutRate,
  174. t: make(map[transactionID]*clientTransaction, 100),
  175. maxAttempts: defaultMaxAttempts,
  176. closeConn: true,
  177. }
  178. for _, o := range options {
  179. o(c)
  180. }
  181. if c.c == nil {
  182. return nil, ErrNoConnection
  183. }
  184. if c.a == nil {
  185. c.a = NewAgent(nil)
  186. }
  187. if err := c.a.SetHandler(c.handleAgentCallback); err != nil {
  188. return nil, err
  189. }
  190. if c.collector == nil {
  191. c.collector = &tickerCollector{
  192. close: make(chan struct{}),
  193. clock: c.clock,
  194. }
  195. }
  196. if err := c.collector.Start(c.rtoRate, func(t time.Time) {
  197. closedOrPanic(c.a.Collect(t))
  198. }); err != nil {
  199. return nil, err
  200. }
  201. c.wg.Add(1)
  202. go c.readUntilClosed()
  203. runtime.SetFinalizer(c, clientFinalizer)
  204. return c, nil
  205. }
  206. func clientFinalizer(c *Client) {
  207. if c == nil {
  208. return
  209. }
  210. err := c.Close()
  211. if errors.Is(err, ErrClientClosed) {
  212. return
  213. }
  214. if err == nil {
  215. log.Println("client: called finalizer on non-closed client") // nolint
  216. return
  217. }
  218. log.Println("client: called finalizer on non-closed client:", err) // nolint
  219. }
  220. // Connection wraps Reader, Writer and Closer interfaces.
  221. type Connection interface {
  222. io.Reader
  223. io.Writer
  224. io.Closer
  225. }
  226. // ClientAgent is Agent implementation that is used by Client to
  227. // process transactions.
  228. type ClientAgent interface {
  229. Process(*Message) error
  230. Close() error
  231. Start(id [TransactionIDSize]byte, deadline time.Time) error
  232. Stop(id [TransactionIDSize]byte) error
  233. Collect(time.Time) error
  234. SetHandler(h Handler) error
  235. }
  236. // Client simulates "connection" to STUN server.
  237. type Client struct {
  238. rto int64 // time.Duration
  239. a ClientAgent
  240. c Connection
  241. close chan struct{}
  242. rtoRate time.Duration
  243. maxAttempts int32
  244. closed bool
  245. closeConn bool // should call c.Close() while closing
  246. wg sync.WaitGroup
  247. clock Clock
  248. handler Handler
  249. collector Collector
  250. t map[transactionID]*clientTransaction
  251. // mux guards closed and t
  252. mux sync.RWMutex
  253. }
  254. // clientTransaction represents transaction in progress.
  255. // If transaction is succeed or failed, f will be called
  256. // provided by event.
  257. // Concurrent access is invalid.
  258. type clientTransaction struct {
  259. id transactionID
  260. attempt int32
  261. calls int32
  262. h Handler
  263. start time.Time
  264. rto time.Duration
  265. raw []byte
  266. }
  267. func (t *clientTransaction) handle(e Event) {
  268. if atomic.AddInt32(&t.calls, 1) == 1 {
  269. t.h(e)
  270. }
  271. }
  272. var clientTransactionPool = &sync.Pool{ //nolint:gochecknoglobals
  273. New: func() interface{} {
  274. return &clientTransaction{
  275. raw: make([]byte, 1500),
  276. }
  277. },
  278. }
  279. func acquireClientTransaction() *clientTransaction {
  280. return clientTransactionPool.Get().(*clientTransaction) //nolint:forcetypeassert
  281. }
  282. func putClientTransaction(t *clientTransaction) {
  283. t.raw = t.raw[:0]
  284. t.start = time.Time{}
  285. t.attempt = 0
  286. t.id = transactionID{}
  287. clientTransactionPool.Put(t)
  288. }
  289. func (t *clientTransaction) nextTimeout(now time.Time) time.Time {
  290. return now.Add(time.Duration(t.attempt+1) * t.rto)
  291. }
  292. // start registers transaction.
  293. //
  294. // Could return ErrClientClosed, ErrTransactionExists.
  295. func (c *Client) start(t *clientTransaction) error {
  296. c.mux.Lock()
  297. defer c.mux.Unlock()
  298. if c.closed {
  299. return ErrClientClosed
  300. }
  301. _, exists := c.t[t.id]
  302. if exists {
  303. return ErrTransactionExists
  304. }
  305. c.t[t.id] = t
  306. return nil
  307. }
  308. // Clock abstracts the source of current time.
  309. type Clock interface {
  310. Now() time.Time
  311. }
  312. type systemClockService struct{}
  313. func (systemClockService) Now() time.Time { return time.Now() }
  314. func systemClock() systemClockService {
  315. return systemClockService{}
  316. }
  317. // SetRTO sets current RTO value.
  318. func (c *Client) SetRTO(rto time.Duration) {
  319. atomic.StoreInt64(&c.rto, int64(rto))
  320. }
  321. // StopErr occurs when Client fails to stop transaction while
  322. // processing error.
  323. //
  324. //nolint:errname
  325. type StopErr struct {
  326. Err error // value returned by Stop()
  327. Cause error // error that caused Stop() call
  328. }
  329. func (e StopErr) Error() string {
  330. return fmt.Sprintf("error while stopping due to %s: %s", sprintErr(e.Cause), sprintErr(e.Err))
  331. }
  332. // CloseErr indicates client close failure.
  333. //
  334. //nolint:errname
  335. type CloseErr struct {
  336. AgentErr error
  337. ConnectionErr error
  338. }
  339. func sprintErr(err error) string {
  340. if err == nil {
  341. return "<nil>" //nolint:goconst
  342. }
  343. return err.Error()
  344. }
  345. func (c CloseErr) Error() string {
  346. return fmt.Sprintf("failed to close: %s (connection), %s (agent)", sprintErr(c.ConnectionErr), sprintErr(c.AgentErr))
  347. }
  348. func (c *Client) readUntilClosed() {
  349. defer c.wg.Done()
  350. m := new(Message)
  351. m.Raw = make([]byte, 1024)
  352. for {
  353. select {
  354. case <-c.close:
  355. return
  356. default:
  357. }
  358. _, err := m.ReadFrom(c.c)
  359. if err == nil {
  360. if pErr := c.a.Process(m); errors.Is(pErr, ErrAgentClosed) {
  361. return
  362. }
  363. }
  364. }
  365. }
  366. func closedOrPanic(err error) {
  367. if err == nil || errors.Is(err, ErrAgentClosed) {
  368. return
  369. }
  370. panic(err) //nolint
  371. }
  372. type tickerCollector struct {
  373. close chan struct{}
  374. wg sync.WaitGroup
  375. clock Clock
  376. }
  377. // Collector calls function f with constant rate.
  378. //
  379. // The simple Collector is ticker which calls function on each tick.
  380. type Collector interface {
  381. Start(rate time.Duration, f func(now time.Time)) error
  382. Close() error
  383. }
  384. func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error {
  385. t := time.NewTicker(rate)
  386. a.wg.Add(1)
  387. go func() {
  388. defer a.wg.Done()
  389. for {
  390. select {
  391. case <-a.close:
  392. t.Stop()
  393. return
  394. case <-t.C:
  395. f(a.clock.Now())
  396. }
  397. }
  398. }()
  399. return nil
  400. }
  401. func (a *tickerCollector) Close() error {
  402. close(a.close)
  403. a.wg.Wait()
  404. return nil
  405. }
  406. // ErrClientClosed indicates that client is closed.
  407. var ErrClientClosed = errors.New("client is closed")
  408. // Close stops internal connection and agent, returning CloseErr on error.
  409. func (c *Client) Close() error {
  410. if err := c.checkInit(); err != nil {
  411. return err
  412. }
  413. c.mux.Lock()
  414. if c.closed {
  415. c.mux.Unlock()
  416. return ErrClientClosed
  417. }
  418. c.closed = true
  419. c.mux.Unlock()
  420. if closeErr := c.collector.Close(); closeErr != nil {
  421. return closeErr
  422. }
  423. var connErr error
  424. agentErr := c.a.Close()
  425. if c.closeConn {
  426. connErr = c.c.Close()
  427. }
  428. close(c.close)
  429. c.wg.Wait()
  430. if agentErr == nil && connErr == nil {
  431. return nil
  432. }
  433. return CloseErr{
  434. AgentErr: agentErr,
  435. ConnectionErr: connErr,
  436. }
  437. }
  438. // Indicate sends indication m to server. Shorthand to Start call
  439. // with zero deadline and callback.
  440. func (c *Client) Indicate(m *Message) error {
  441. return c.Start(m, nil)
  442. }
  443. // callbackWaitHandler blocks on wait() call until callback is called.
  444. type callbackWaitHandler struct {
  445. handler Handler
  446. callback func(event Event)
  447. cond *sync.Cond
  448. processed bool
  449. }
  450. func (s *callbackWaitHandler) HandleEvent(e Event) {
  451. s.cond.L.Lock()
  452. if s.callback == nil {
  453. panic("s.callback is nil") //nolint
  454. }
  455. s.callback(e)
  456. s.processed = true
  457. s.cond.Broadcast()
  458. s.cond.L.Unlock()
  459. }
  460. func (s *callbackWaitHandler) wait() {
  461. s.cond.L.Lock()
  462. for !s.processed {
  463. s.cond.Wait()
  464. }
  465. s.processed = false
  466. s.callback = nil
  467. s.cond.L.Unlock()
  468. }
  469. func (s *callbackWaitHandler) setCallback(f func(event Event)) {
  470. if f == nil {
  471. panic("f is nil") //nolint
  472. }
  473. s.cond.L.Lock()
  474. s.callback = f
  475. if s.handler == nil {
  476. s.handler = s.HandleEvent
  477. }
  478. s.cond.L.Unlock()
  479. }
  480. var callbackWaitHandlerPool = sync.Pool{ //nolint:gochecknoglobals
  481. New: func() interface{} {
  482. return &callbackWaitHandler{
  483. cond: sync.NewCond(new(sync.Mutex)),
  484. }
  485. },
  486. }
  487. // ErrClientNotInitialized means that client connection or agent is nil.
  488. var ErrClientNotInitialized = errors.New("client not initialized")
  489. func (c *Client) checkInit() error {
  490. if c == nil || c.c == nil || c.a == nil || c.close == nil {
  491. return ErrClientNotInitialized
  492. }
  493. return nil
  494. }
  495. // Do is Start wrapper that waits until callback is called. If no callback
  496. // provided, Indicate is called instead.
  497. //
  498. // Do has cpu overhead due to blocking, see BenchmarkClient_Do.
  499. // Use Start method for less overhead.
  500. func (c *Client) Do(m *Message, f func(Event)) error {
  501. if err := c.checkInit(); err != nil {
  502. return err
  503. }
  504. if f == nil {
  505. return c.Indicate(m)
  506. }
  507. h := callbackWaitHandlerPool.Get().(*callbackWaitHandler) //nolint:forcetypeassert
  508. h.setCallback(f)
  509. defer func() {
  510. callbackWaitHandlerPool.Put(h)
  511. }()
  512. if err := c.Start(m, h.handler); err != nil {
  513. return err
  514. }
  515. h.wait()
  516. return nil
  517. }
  518. func (c *Client) delete(id transactionID) {
  519. c.mux.Lock()
  520. if c.t != nil {
  521. delete(c.t, id)
  522. }
  523. c.mux.Unlock()
  524. }
  525. type buffer struct {
  526. buf []byte
  527. }
  528. var bufferPool = &sync.Pool{ //nolint:gochecknoglobals
  529. New: func() interface{} {
  530. return &buffer{buf: make([]byte, 2048)}
  531. },
  532. }
  533. func (c *Client) handleAgentCallback(e Event) {
  534. c.mux.Lock()
  535. if c.closed {
  536. c.mux.Unlock()
  537. return
  538. }
  539. t, found := c.t[e.TransactionID]
  540. if found {
  541. delete(c.t, t.id)
  542. }
  543. c.mux.Unlock()
  544. if !found {
  545. if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
  546. c.handler(e)
  547. }
  548. // Ignoring.
  549. return
  550. }
  551. if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
  552. // Transaction completed.
  553. t.handle(e)
  554. putClientTransaction(t)
  555. return
  556. }
  557. // Doing re-transmission.
  558. t.attempt++
  559. b := bufferPool.Get().(*buffer) //nolint:forcetypeassert
  560. b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
  561. defer bufferPool.Put(b)
  562. var (
  563. now = c.clock.Now()
  564. timeOut = t.nextTimeout(now)
  565. id = t.id
  566. )
  567. // Starting client transaction.
  568. if startErr := c.start(t); startErr != nil {
  569. c.delete(id)
  570. e.Error = startErr
  571. t.handle(e)
  572. putClientTransaction(t)
  573. return
  574. }
  575. // Starting agent transaction.
  576. if startErr := c.a.Start(id, timeOut); startErr != nil {
  577. c.delete(id)
  578. e.Error = startErr
  579. t.handle(e)
  580. putClientTransaction(t)
  581. return
  582. }
  583. // Writing message to connection again.
  584. _, writeErr := c.c.Write(b.buf)
  585. if writeErr != nil {
  586. c.delete(id)
  587. e.Error = writeErr
  588. // Stopping agent transaction instead of waiting until it's deadline.
  589. // This will call handleAgentCallback with "ErrTransactionStopped" error
  590. // which will be ignored.
  591. if stopErr := c.a.Stop(id); stopErr != nil {
  592. // Failed to stop agent transaction. Wrapping the error in StopError.
  593. e.Error = StopErr{
  594. Err: stopErr,
  595. Cause: writeErr,
  596. }
  597. }
  598. t.handle(e)
  599. putClientTransaction(t)
  600. return
  601. }
  602. }
  603. // Start starts transaction (if h set) and writes message to server, handler
  604. // is called asynchronously.
  605. func (c *Client) Start(m *Message, h Handler) error {
  606. if err := c.checkInit(); err != nil {
  607. return err
  608. }
  609. c.mux.RLock()
  610. closed := c.closed
  611. c.mux.RUnlock()
  612. if closed {
  613. return ErrClientClosed
  614. }
  615. if h != nil {
  616. // Starting transaction only if h is set. Useful for indications.
  617. t := acquireClientTransaction()
  618. t.id = m.TransactionID
  619. t.start = c.clock.Now()
  620. t.h = h
  621. t.rto = time.Duration(atomic.LoadInt64(&c.rto))
  622. t.attempt = 0
  623. t.raw = append(t.raw[:0], m.Raw...)
  624. t.calls = 0
  625. d := t.nextTimeout(t.start)
  626. if err := c.start(t); err != nil {
  627. return err
  628. }
  629. if err := c.a.Start(m.TransactionID, d); err != nil {
  630. return err
  631. }
  632. }
  633. _, err := m.WriteTo(c.c)
  634. if err != nil && h != nil {
  635. c.delete(m.TransactionID)
  636. // Stopping transaction instead of waiting until deadline.
  637. if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
  638. return StopErr{
  639. Err: stopErr,
  640. Cause: err,
  641. }
  642. }
  643. }
  644. return err
  645. }