12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package stun
- import (
- "crypto/tls"
- "errors"
- "fmt"
- "io"
- "log"
- "net"
- "runtime"
- "strconv"
- "sync"
- "sync/atomic"
- "time"
- "github.com/pion/dtls/v2"
- "github.com/pion/transport/v2"
- "github.com/pion/transport/v2/stdnet"
- )
var (
	// ErrUnsupportedURI is returned by DialURI when the user passes a STUN or
	// TURN URI whose scheme/transport combination it cannot dial.
	ErrUnsupportedURI = fmt.Errorf("invalid schema or transport")
)
- // Dial connects to the address on the named network and then
- // initializes Client on that connection, returning error if any.
- func Dial(network, address string) (*Client, error) {
- conn, err := net.Dial(network, address)
- if err != nil {
- return nil, err
- }
- return NewClient(conn)
- }
// DialConfig is used to pass configuration to DialURI().
type DialConfig struct {
	// DTLSConfig is used (as a copy) for "turns" URIs over UDP; DialURI
	// overrides ServerName with the URI host.
	DTLSConfig dtls.Config

	// TLSConfig is used (as a copy) for "turns"/"stuns" URIs over TCP;
	// DialURI overrides ServerName with the URI host.
	TLSConfig tls.Config

	// Net is the network stack used for dialing. When nil, DialURI
	// creates one with stdnet.NewNet.
	Net transport.Net
}
- // DialURI connect to the STUN/TURN URI and then
- // initializes Client on that connection, returning error if any.
- func DialURI(uri *URI, cfg *DialConfig) (*Client, error) {
- var conn Connection
- var err error
- nw := cfg.Net
- if nw == nil {
- nw, err = stdnet.NewNet()
- if err != nil {
- return nil, fmt.Errorf("failed to create net: %w", err)
- }
- }
- addr := net.JoinHostPort(uri.Host, strconv.Itoa(uri.Port))
- switch {
- case uri.Scheme == SchemeTypeSTUN:
- if conn, err = nw.Dial("udp", addr); err != nil {
- return nil, fmt.Errorf("failed to listen: %w", err)
- }
- case uri.Scheme == SchemeTypeTURN:
- network := "udp" //nolint:goconst
- if uri.Proto == ProtoTypeTCP {
- network = "tcp" //nolint:goconst
- }
- if conn, err = nw.Dial(network, addr); err != nil {
- return nil, fmt.Errorf("failed to dial: %w", err)
- }
- case uri.Scheme == SchemeTypeTURNS && uri.Proto == ProtoTypeUDP:
- dtlsCfg := cfg.DTLSConfig // Copy
- dtlsCfg.ServerName = uri.Host
- udpConn, err := nw.Dial("udp", addr)
- if err != nil {
- return nil, fmt.Errorf("failed to dial: %w", err)
- }
- if conn, err = dtls.Client(udpConn, &dtlsCfg); err != nil {
- return nil, fmt.Errorf("failed to connect to '%s': %w", addr, err)
- }
- case (uri.Scheme == SchemeTypeTURNS || uri.Scheme == SchemeTypeSTUNS) && uri.Proto == ProtoTypeTCP:
- tlsCfg := cfg.TLSConfig //nolint:govet
- tlsCfg.ServerName = uri.Host
- tcpConn, err := nw.Dial("tcp", addr)
- if err != nil {
- return nil, fmt.Errorf("failed to dial: %w", err)
- }
- conn = tls.Client(tcpConn, &tlsCfg)
- default:
- return nil, ErrUnsupportedURI
- }
- return NewClient(conn)
- }
var (
	// ErrNoConnection is returned by NewClient when it ends up without a
	// connection, i.e. conn is nil after all options have been applied.
	ErrNoConnection = errors.New("no connection provided")
)
- // ClientOption sets some client option.
- type ClientOption func(c *Client)
- // WithHandler sets client handler which is called if Agent emits the Event
- // with TransactionID that is not currently registered by Client.
- // Useful for handling Data indications from TURN server.
- func WithHandler(h Handler) ClientOption {
- return func(c *Client) {
- c.handler = h
- }
- }
- // WithRTO sets client RTO as defined in STUN RFC.
- func WithRTO(rto time.Duration) ClientOption {
- return func(c *Client) {
- c.rto = int64(rto)
- }
- }
- // WithClock sets Clock of client, the source of current time.
- // Also clock is passed to default collector if set.
- func WithClock(clock Clock) ClientOption {
- return func(c *Client) {
- c.clock = clock
- }
- }
- // WithTimeoutRate sets RTO timer minimum resolution.
- func WithTimeoutRate(d time.Duration) ClientOption {
- return func(c *Client) {
- c.rtoRate = d
- }
- }
- // WithAgent sets client STUN agent.
- //
- // Defaults to agent implementation in current package,
- // see agent.go.
- func WithAgent(a ClientAgent) ClientOption {
- return func(c *Client) {
- c.a = a
- }
- }
- // WithCollector rests client timeout collector, the implementation
- // of ticker which calls function on each tick.
- func WithCollector(coll Collector) ClientOption {
- return func(c *Client) {
- c.collector = coll
- }
- }
- // WithNoConnClose prevents client from closing underlying connection when
- // the Close() method is called.
- func WithNoConnClose() ClientOption {
- return func(c *Client) {
- c.closeConn = false
- }
- }
- // WithNoRetransmit disables retransmissions and sets RTO to
- // defaultMaxAttempts * defaultRTO which will be effectively time out
- // if not set.
- //
- // Useful for TCP connections where transport handles RTO.
- func WithNoRetransmit(c *Client) {
- c.maxAttempts = 0
- if c.rto == 0 {
- c.rto = defaultMaxAttempts * int64(defaultRTO)
- }
- }
// Default client parameters, used by NewClient unless overridden by options.
const (
	// defaultRTO is the initial retransmission timeout.
	defaultRTO = 300 * time.Millisecond
	// defaultTimeoutRate is the collector tick period (RTO timer resolution).
	defaultTimeoutRate = 5 * time.Millisecond
	// defaultMaxAttempts is the retransmission limit per transaction.
	defaultMaxAttempts = 7
)
- // NewClient initializes new Client from provided options,
- // starting internal goroutines and using default options fields
- // if necessary. Call Close method after using Client to close conn and
- // release resources.
- //
- // The conn will be closed on Close call. Use WithNoConnClose option to
- // prevent that.
- //
- // Note that user should handle the protocol multiplexing, client does not
- // provide any API for it, so if you need to read application data, wrap the
- // connection with your (de-)multiplexer and pass the wrapper as conn.
- func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
- c := &Client{
- close: make(chan struct{}),
- c: conn,
- clock: systemClock(),
- rto: int64(defaultRTO),
- rtoRate: defaultTimeoutRate,
- t: make(map[transactionID]*clientTransaction, 100),
- maxAttempts: defaultMaxAttempts,
- closeConn: true,
- }
- for _, o := range options {
- o(c)
- }
- if c.c == nil {
- return nil, ErrNoConnection
- }
- if c.a == nil {
- c.a = NewAgent(nil)
- }
- if err := c.a.SetHandler(c.handleAgentCallback); err != nil {
- return nil, err
- }
- if c.collector == nil {
- c.collector = &tickerCollector{
- close: make(chan struct{}),
- clock: c.clock,
- }
- }
- if err := c.collector.Start(c.rtoRate, func(t time.Time) {
- closedOrPanic(c.a.Collect(t))
- }); err != nil {
- return nil, err
- }
- c.wg.Add(1)
- go c.readUntilClosed()
- runtime.SetFinalizer(c, clientFinalizer)
- return c, nil
- }
- func clientFinalizer(c *Client) {
- if c == nil {
- return
- }
- err := c.Close()
- if errors.Is(err, ErrClientClosed) {
- return
- }
- if err == nil {
- log.Println("client: called finalizer on non-closed client") // nolint
- return
- }
- log.Println("client: called finalizer on non-closed client:", err) // nolint
- }
// Connection is the transport a Client runs on: anything that can be read
// from, written to and closed (the io.Reader, io.Writer and io.Closer
// method set).
type Connection interface {
	io.ReadWriteCloser
}
- // ClientAgent is Agent implementation that is used by Client to
- // process transactions.
- type ClientAgent interface {
- Process(*Message) error
- Close() error
- Start(id [TransactionIDSize]byte, deadline time.Time) error
- Stop(id [TransactionIDSize]byte) error
- Collect(time.Time) error
- SetHandler(h Handler) error
- }
// Client simulates "connection" to STUN server.
//
// Requests are written by Start/Do. A background goroutine
// (readUntilClosed) reads messages and hands them to the ClientAgent. The
// agent reports completions and timeouts back through handleAgentCallback,
// which also performs retransmissions.
type Client struct {
	// rto is the retransmission timeout as int64 nanoseconds
	// (time.Duration). It is accessed with sync/atomic (SetRTO, Start).
	// Keep it as the FIRST field: on 32-bit platforms sync/atomic only
	// guarantees 64-bit alignment for the first word of an allocated
	// struct. Do not reorder.
	rto int64 // time.Duration

	// a receives processed messages, tracks transaction deadlines and
	// emits events to handleAgentCallback.
	a ClientAgent

	// c is the underlying connection.
	c Connection

	// close is closed by Close to stop readUntilClosed.
	close chan struct{}

	// rtoRate is the tick period passed to the collector.
	rtoRate time.Duration

	// maxAttempts is the retransmission limit: a failed transaction is
	// re-sent while its attempt counter is below this value. Read with
	// sync/atomic in handleAgentCallback.
	maxAttempts int32

	// closed is set once by Close.
	closed bool

	// closeConn is cleared by WithNoConnClose.
	closeConn bool // should call c.Close() while closing

	// wg tracks the readUntilClosed goroutine.
	wg sync.WaitGroup

	// clock is the source of current time for transaction deadlines.
	clock Clock

	// handler is the optional WithHandler callback for events whose
	// transaction is not registered with the client.
	handler Handler

	// collector periodically invokes the agent's Collect.
	collector Collector

	// t holds in-flight transactions keyed by transaction ID.
	t map[transactionID]*clientTransaction

	// mux guards closed and t
	mux sync.RWMutex
}
- // clientTransaction represents transaction in progress.
- // If transaction is succeed or failed, f will be called
- // provided by event.
- // Concurrent access is invalid.
- type clientTransaction struct {
- id transactionID
- attempt int32
- calls int32
- h Handler
- start time.Time
- rto time.Duration
- raw []byte
- }
- func (t *clientTransaction) handle(e Event) {
- if atomic.AddInt32(&t.calls, 1) == 1 {
- t.h(e)
- }
- }
- var clientTransactionPool = &sync.Pool{ //nolint:gochecknoglobals
- New: func() interface{} {
- return &clientTransaction{
- raw: make([]byte, 1500),
- }
- },
- }
- func acquireClientTransaction() *clientTransaction {
- return clientTransactionPool.Get().(*clientTransaction) //nolint:forcetypeassert
- }
- func putClientTransaction(t *clientTransaction) {
- t.raw = t.raw[:0]
- t.start = time.Time{}
- t.attempt = 0
- t.id = transactionID{}
- clientTransactionPool.Put(t)
- }
- func (t *clientTransaction) nextTimeout(now time.Time) time.Time {
- return now.Add(time.Duration(t.attempt+1) * t.rto)
- }
- // start registers transaction.
- //
- // Could return ErrClientClosed, ErrTransactionExists.
- func (c *Client) start(t *clientTransaction) error {
- c.mux.Lock()
- defer c.mux.Unlock()
- if c.closed {
- return ErrClientClosed
- }
- _, exists := c.t[t.id]
- if exists {
- return ErrTransactionExists
- }
- c.t[t.id] = t
- return nil
- }
// Clock abstracts the source of current time, so that tests can control it
// (see WithClock).
type Clock interface {
	// Now returns the current time.
	Now() time.Time
}
// systemClockService is the default Clock, backed by time.Now.
type systemClockService struct{}

// Now returns the current local time.
func (systemClockService) Now() time.Time {
	return time.Now()
}

// systemClock returns the default Clock implementation.
func systemClock() systemClockService {
	var clk systemClockService
	return clk
}
- // SetRTO sets current RTO value.
- func (c *Client) SetRTO(rto time.Duration) {
- atomic.StoreInt64(&c.rto, int64(rto))
- }
// StopErr occurs when Client fails to stop transaction while
// processing error.
//
// The Client produces it after writing a request fails and the follow-up
// call to the agent's Stop fails as well. Start returns it, and during a
// retransmission it is the error delivered to the transaction handler.
//
//nolint:errname
type StopErr struct {
	Err   error // value returned by Stop()
	Cause error // error that caused Stop() call
}
- func (e StopErr) Error() string {
- return fmt.Sprintf("error while stopping due to %s: %s", sprintErr(e.Cause), sprintErr(e.Err))
- }
// CloseErr indicates client close failure.
//
// Client.Close returns it when closing the agent or the connection fails.
// A nil field means that step succeeded. ConnectionErr is also nil when the
// connection was not closed because of WithNoConnClose.
//
//nolint:errname
type CloseErr struct {
	// AgentErr is the error returned by the agent's Close.
	AgentErr error
	// ConnectionErr is the error returned by the connection's Close.
	ConnectionErr error
}
// sprintErr renders err for inclusion in a message, using "<nil>" for a nil
// error instead of panicking on err.Error().
func sprintErr(err error) string {
	if err != nil {
		return err.Error()
	}
	return "<nil>" //nolint:goconst
}
- func (c CloseErr) Error() string {
- return fmt.Sprintf("failed to close: %s (connection), %s (agent)", sprintErr(c.ConnectionErr), sprintErr(c.AgentErr))
- }
- func (c *Client) readUntilClosed() {
- defer c.wg.Done()
- m := new(Message)
- m.Raw = make([]byte, 1024)
- for {
- select {
- case <-c.close:
- return
- default:
- }
- _, err := m.ReadFrom(c.c)
- if err == nil {
- if pErr := c.a.Process(m); errors.Is(pErr, ErrAgentClosed) {
- return
- }
- }
- }
- }
- func closedOrPanic(err error) {
- if err == nil || errors.Is(err, ErrAgentClosed) {
- return
- }
- panic(err) //nolint
- }
- type tickerCollector struct {
- close chan struct{}
- wg sync.WaitGroup
- clock Clock
- }
// Collector calls function f with constant rate.
//
// The simple Collector is a ticker which calls the function on each tick.
type Collector interface {
	// Close stops calling f and releases the collector's resources.
	Close() error
	// Start begins calling f, with the current time, once per rate.
	Start(rate time.Duration, f func(now time.Time)) error
}
- func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error {
- t := time.NewTicker(rate)
- a.wg.Add(1)
- go func() {
- defer a.wg.Done()
- for {
- select {
- case <-a.close:
- t.Stop()
- return
- case <-t.C:
- f(a.clock.Now())
- }
- }
- }()
- return nil
- }
- func (a *tickerCollector) Close() error {
- close(a.close)
- a.wg.Wait()
- return nil
- }
var (
	// ErrClientClosed indicates that client is closed: it is returned when
	// starting a transaction on, or closing again, a client that was closed.
	ErrClientClosed = errors.New("client is closed")
)
- // Close stops internal connection and agent, returning CloseErr on error.
- func (c *Client) Close() error {
- if err := c.checkInit(); err != nil {
- return err
- }
- c.mux.Lock()
- if c.closed {
- c.mux.Unlock()
- return ErrClientClosed
- }
- c.closed = true
- c.mux.Unlock()
- if closeErr := c.collector.Close(); closeErr != nil {
- return closeErr
- }
- var connErr error
- agentErr := c.a.Close()
- if c.closeConn {
- connErr = c.c.Close()
- }
- close(c.close)
- c.wg.Wait()
- if agentErr == nil && connErr == nil {
- return nil
- }
- return CloseErr{
- AgentErr: agentErr,
- ConnectionErr: connErr,
- }
- }
- // Indicate sends indication m to server. Shorthand to Start call
- // with zero deadline and callback.
- func (c *Client) Indicate(m *Message) error {
- return c.Start(m, nil)
- }
- // callbackWaitHandler blocks on wait() call until callback is called.
- type callbackWaitHandler struct {
- handler Handler
- callback func(event Event)
- cond *sync.Cond
- processed bool
- }
- func (s *callbackWaitHandler) HandleEvent(e Event) {
- s.cond.L.Lock()
- if s.callback == nil {
- panic("s.callback is nil") //nolint
- }
- s.callback(e)
- s.processed = true
- s.cond.Broadcast()
- s.cond.L.Unlock()
- }
- func (s *callbackWaitHandler) wait() {
- s.cond.L.Lock()
- for !s.processed {
- s.cond.Wait()
- }
- s.processed = false
- s.callback = nil
- s.cond.L.Unlock()
- }
- func (s *callbackWaitHandler) setCallback(f func(event Event)) {
- if f == nil {
- panic("f is nil") //nolint
- }
- s.cond.L.Lock()
- s.callback = f
- if s.handler == nil {
- s.handler = s.HandleEvent
- }
- s.cond.L.Unlock()
- }
- var callbackWaitHandlerPool = sync.Pool{ //nolint:gochecknoglobals
- New: func() interface{} {
- return &callbackWaitHandler{
- cond: sync.NewCond(new(sync.Mutex)),
- }
- },
- }
var (
	// ErrClientNotInitialized means that the client, its connection, its
	// agent or its close channel is nil, e.g. a Client not created through
	// NewClient.
	ErrClientNotInitialized = errors.New("client not initialized")
)
- func (c *Client) checkInit() error {
- if c == nil || c.c == nil || c.a == nil || c.close == nil {
- return ErrClientNotInitialized
- }
- return nil
- }
- // Do is Start wrapper that waits until callback is called. If no callback
- // provided, Indicate is called instead.
- //
- // Do has cpu overhead due to blocking, see BenchmarkClient_Do.
- // Use Start method for less overhead.
- func (c *Client) Do(m *Message, f func(Event)) error {
- if err := c.checkInit(); err != nil {
- return err
- }
- if f == nil {
- return c.Indicate(m)
- }
- h := callbackWaitHandlerPool.Get().(*callbackWaitHandler) //nolint:forcetypeassert
- h.setCallback(f)
- defer func() {
- callbackWaitHandlerPool.Put(h)
- }()
- if err := c.Start(m, h.handler); err != nil {
- return err
- }
- h.wait()
- return nil
- }
- func (c *Client) delete(id transactionID) {
- c.mux.Lock()
- if c.t != nil {
- delete(c.t, id)
- }
- c.mux.Unlock()
- }
// buffer is a pooled scratch byte slice (see bufferPool).
type buffer struct {
	buf []byte
}

// bufferPool recycles scratch buffers used when re-sending requests.
// Freshly created buffers hold 2048 bytes.
var bufferPool = &sync.Pool{ //nolint:gochecknoglobals
	New: func() interface{} {
		b := new(buffer)
		b.buf = make([]byte, 2048)
		return b
	},
}
- func (c *Client) handleAgentCallback(e Event) {
- c.mux.Lock()
- if c.closed {
- c.mux.Unlock()
- return
- }
- t, found := c.t[e.TransactionID]
- if found {
- delete(c.t, t.id)
- }
- c.mux.Unlock()
- if !found {
- if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
- c.handler(e)
- }
- // Ignoring.
- return
- }
- if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
- // Transaction completed.
- t.handle(e)
- putClientTransaction(t)
- return
- }
- // Doing re-transmission.
- t.attempt++
- b := bufferPool.Get().(*buffer) //nolint:forcetypeassert
- b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
- defer bufferPool.Put(b)
- var (
- now = c.clock.Now()
- timeOut = t.nextTimeout(now)
- id = t.id
- )
- // Starting client transaction.
- if startErr := c.start(t); startErr != nil {
- c.delete(id)
- e.Error = startErr
- t.handle(e)
- putClientTransaction(t)
- return
- }
- // Starting agent transaction.
- if startErr := c.a.Start(id, timeOut); startErr != nil {
- c.delete(id)
- e.Error = startErr
- t.handle(e)
- putClientTransaction(t)
- return
- }
- // Writing message to connection again.
- _, writeErr := c.c.Write(b.buf)
- if writeErr != nil {
- c.delete(id)
- e.Error = writeErr
- // Stopping agent transaction instead of waiting until it's deadline.
- // This will call handleAgentCallback with "ErrTransactionStopped" error
- // which will be ignored.
- if stopErr := c.a.Stop(id); stopErr != nil {
- // Failed to stop agent transaction. Wrapping the error in StopError.
- e.Error = StopErr{
- Err: stopErr,
- Cause: writeErr,
- }
- }
- t.handle(e)
- putClientTransaction(t)
- return
- }
- }
- // Start starts transaction (if h set) and writes message to server, handler
- // is called asynchronously.
- func (c *Client) Start(m *Message, h Handler) error {
- if err := c.checkInit(); err != nil {
- return err
- }
- c.mux.RLock()
- closed := c.closed
- c.mux.RUnlock()
- if closed {
- return ErrClientClosed
- }
- if h != nil {
- // Starting transaction only if h is set. Useful for indications.
- t := acquireClientTransaction()
- t.id = m.TransactionID
- t.start = c.clock.Now()
- t.h = h
- t.rto = time.Duration(atomic.LoadInt64(&c.rto))
- t.attempt = 0
- t.raw = append(t.raw[:0], m.Raw...)
- t.calls = 0
- d := t.nextTimeout(t.start)
- if err := c.start(t); err != nil {
- return err
- }
- if err := c.a.Start(m.TransactionID, d); err != nil {
- return err
- }
- }
- _, err := m.WriteTo(c.c)
- if err != nil && h != nil {
- c.delete(m.TransactionID)
- // Stopping transaction instead of waiting until deadline.
- if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
- return StopErr{
- Err: stopErr,
- Cause: err,
- }
- }
- }
- return err
- }
|