message.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package stun
  4. import (
  5. "crypto/rand"
  6. "encoding/base64"
  7. "errors"
  8. "fmt"
  9. "io"
  10. )
  11. const (
  12. // magicCookie is fixed value that aids in distinguishing STUN packets
  13. // from packets of other protocols when STUN is multiplexed with those
  14. // other protocols on the same Port.
  15. //
  16. // The magic cookie field MUST contain the fixed value 0x2112A442 in
  17. // network byte order.
  18. //
  19. // Defined in "STUN Message Structure", section 6.
  20. magicCookie = 0x2112A442
  21. attributeHeaderSize = 4
  22. messageHeaderSize = 20
  23. // TransactionIDSize is length of transaction id array (in bytes).
  24. TransactionIDSize = 12 // 96 bit
  25. )
  26. // NewTransactionID returns new random transaction ID using crypto/rand
  27. // as source.
  28. func NewTransactionID() (b [TransactionIDSize]byte) {
  29. readFullOrPanic(rand.Reader, b[:])
  30. return b
  31. }
  32. // IsMessage returns true if b looks like STUN message.
  33. // Useful for multiplexing. IsMessage does not guarantee
  34. // that decoding will be successful.
  35. func IsMessage(b []byte) bool {
  36. return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie
  37. }
  38. // New returns *Message with pre-allocated Raw.
  39. func New() *Message {
  40. const defaultRawCapacity = 120
  41. return &Message{
  42. Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
  43. }
  44. }
  45. // ErrDecodeToNil occurs on Decode(data, nil) call.
  46. var ErrDecodeToNil = errors.New("attempt to decode to nil message")
  47. // Decode decodes Message from data to m, returning error if any.
  48. func Decode(data []byte, m *Message) error {
  49. if m == nil {
  50. return ErrDecodeToNil
  51. }
  52. m.Raw = append(m.Raw[:0], data...)
  53. return m.Decode()
  54. }
  55. // Message represents a single STUN packet. It uses aggressive internal
  56. // buffering to enable zero-allocation encoding and decoding,
  57. // so there are some usage constraints:
  58. //
  59. // Message, its fields, results of m.Get or any attribute a.GetFrom
  60. // are valid only until Message.Raw is not modified.
  61. type Message struct {
  62. Type MessageType
  63. Length uint32 // len(Raw) not including header
  64. TransactionID [TransactionIDSize]byte
  65. Attributes Attributes
  66. Raw []byte
  67. }
  68. // MarshalBinary implements the encoding.BinaryMarshaler interface.
  69. func (m Message) MarshalBinary() (data []byte, err error) {
  70. // We can't return m.Raw, allocation is expected by implicit interface
  71. // contract induced by other implementations.
  72. b := make([]byte, len(m.Raw))
  73. copy(b, m.Raw)
  74. return b, nil
  75. }
  76. // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
  77. func (m *Message) UnmarshalBinary(data []byte) error {
  78. // We can't retain data, copy is expected by interface contract.
  79. m.Raw = append(m.Raw[:0], data...)
  80. return m.Decode()
  81. }
  82. // GobEncode implements the gob.GobEncoder interface.
  83. func (m Message) GobEncode() ([]byte, error) {
  84. return m.MarshalBinary()
  85. }
  86. // GobDecode implements the gob.GobDecoder interface.
  87. func (m *Message) GobDecode(data []byte) error {
  88. return m.UnmarshalBinary(data)
  89. }
  90. // AddTo sets b.TransactionID to m.TransactionID.
  91. //
  92. // Implements Setter to aid in crafting responses.
  93. func (m *Message) AddTo(b *Message) error {
  94. b.TransactionID = m.TransactionID
  95. b.WriteTransactionID()
  96. return nil
  97. }
  98. // NewTransactionID sets m.TransactionID to random value from crypto/rand
  99. // and returns error if any.
  100. func (m *Message) NewTransactionID() error {
  101. _, err := io.ReadFull(rand.Reader, m.TransactionID[:])
  102. if err == nil {
  103. m.WriteTransactionID()
  104. }
  105. return err
  106. }
  107. func (m *Message) String() string {
  108. tID := base64.StdEncoding.EncodeToString(m.TransactionID[:])
  109. aInfo := ""
  110. for k, a := range m.Attributes {
  111. aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type)
  112. }
  113. return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo)
  114. }
  115. // Reset resets Message, attributes and underlying buffer length.
  116. func (m *Message) Reset() {
  117. m.Raw = m.Raw[:0]
  118. m.Length = 0
  119. m.Attributes = m.Attributes[:0]
  120. }
  121. // grow ensures that internal buffer has n length.
  122. func (m *Message) grow(n int) {
  123. if len(m.Raw) >= n {
  124. return
  125. }
  126. if cap(m.Raw) >= n {
  127. m.Raw = m.Raw[:n]
  128. return
  129. }
  130. m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...)
  131. }
  132. // Add appends new attribute to message. Not goroutine-safe.
  133. //
  134. // Value of attribute is copied to internal buffer so
  135. // it is safe to reuse v.
  136. func (m *Message) Add(t AttrType, v []byte) {
  137. // Allocating buffer for TLV (type-length-value).
  138. // T = t, L = len(v), V = v.
  139. // m.Raw will look like:
  140. // [0:20] <- message header
  141. // [20:20+m.Length] <- existing message attributes
  142. // [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV
  143. // [first:last] <- same as previous
  144. // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
  145. // T L V
  146. allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V)
  147. first := messageHeaderSize + int(m.Length) // first byte number
  148. last := first + allocSize // last byte number
  149. m.grow(last) // growing cap(Raw) to fit TLV
  150. m.Raw = m.Raw[:last] // now len(Raw) = last
  151. m.Length += uint32(allocSize) // rendering length change
  152. // Sub-slicing internal buffer to simplify encoding.
  153. buf := m.Raw[first:last] // slice for TLV
  154. value := buf[attributeHeaderSize:] // slice for V
  155. attr := RawAttribute{
  156. Type: t, // T
  157. Length: uint16(len(v)), // L
  158. Value: value, // V
  159. }
  160. // Encoding attribute TLV to allocated buffer.
  161. bin.PutUint16(buf[0:2], attr.Type.Value()) // T
  162. bin.PutUint16(buf[2:4], attr.Length) // L
  163. copy(value, v) // V
  164. // Checking that attribute value needs padding.
  165. if attr.Length%padding != 0 {
  166. // Performing padding.
  167. bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
  168. last += bytesToAdd
  169. m.grow(last)
  170. // setting all padding bytes to zero
  171. // to prevent data leak from previous
  172. // data in next bytesToAdd bytes
  173. buf = m.Raw[last-bytesToAdd : last]
  174. for i := range buf {
  175. buf[i] = 0
  176. }
  177. m.Raw = m.Raw[:last] // increasing buffer length
  178. m.Length += uint32(bytesToAdd) // rendering length change
  179. }
  180. m.Attributes = append(m.Attributes, attr)
  181. m.WriteLength()
  182. }
  183. func attrSliceEqual(a, b Attributes) bool {
  184. for _, attr := range a {
  185. found := false
  186. for _, attrB := range b {
  187. if attrB.Type != attr.Type {
  188. continue
  189. }
  190. if attrB.Equal(attr) {
  191. found = true
  192. break
  193. }
  194. }
  195. if !found {
  196. return false
  197. }
  198. }
  199. return true
  200. }
  201. func attrEqual(a, b Attributes) bool {
  202. if a == nil && b == nil {
  203. return true
  204. }
  205. if a == nil || b == nil {
  206. return false
  207. }
  208. if len(a) != len(b) {
  209. return false
  210. }
  211. if !attrSliceEqual(a, b) {
  212. return false
  213. }
  214. if !attrSliceEqual(b, a) {
  215. return false
  216. }
  217. return true
  218. }
  219. // Equal returns true if Message b equals to m.
  220. // Ignores m.Raw.
  221. func (m *Message) Equal(b *Message) bool {
  222. if m == nil && b == nil {
  223. return true
  224. }
  225. if m == nil || b == nil {
  226. return false
  227. }
  228. if m.Type != b.Type {
  229. return false
  230. }
  231. if m.TransactionID != b.TransactionID {
  232. return false
  233. }
  234. if m.Length != b.Length {
  235. return false
  236. }
  237. if !attrEqual(m.Attributes, b.Attributes) {
  238. return false
  239. }
  240. return true
  241. }
  242. // WriteLength writes m.Length to m.Raw.
  243. func (m *Message) WriteLength() {
  244. m.grow(4)
  245. bin.PutUint16(m.Raw[2:4], uint16(m.Length))
  246. }
  247. // WriteHeader writes header to underlying buffer. Not goroutine-safe.
  248. func (m *Message) WriteHeader() {
  249. m.grow(messageHeaderSize)
  250. _ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below
  251. m.WriteType()
  252. m.WriteLength()
  253. bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie
  254. copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
  255. }
  256. // WriteTransactionID writes m.TransactionID to m.Raw.
  257. func (m *Message) WriteTransactionID() {
  258. copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
  259. }
  260. // WriteAttributes encodes all m.Attributes to m.
  261. func (m *Message) WriteAttributes() {
  262. attributes := m.Attributes
  263. m.Attributes = attributes[:0]
  264. for _, a := range attributes {
  265. m.Add(a.Type, a.Value)
  266. }
  267. m.Attributes = attributes
  268. }
  269. // WriteType writes m.Type to m.Raw.
  270. func (m *Message) WriteType() {
  271. m.grow(2)
  272. bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type
  273. }
  274. // SetType sets m.Type and writes it to m.Raw.
  275. func (m *Message) SetType(t MessageType) {
  276. m.Type = t
  277. m.WriteType()
  278. }
  279. // Encode re-encodes message into m.Raw.
  280. func (m *Message) Encode() {
  281. m.Raw = m.Raw[:0]
  282. m.WriteHeader()
  283. m.Length = 0
  284. m.WriteAttributes()
  285. }
  286. // WriteTo implements WriterTo via calling Write(m.Raw) on w and returning
  287. // call result.
  288. func (m *Message) WriteTo(w io.Writer) (int64, error) {
  289. n, err := w.Write(m.Raw)
  290. return int64(n), err
  291. }
  292. // ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
  293. // Decodes it and return error if any. If m.Raw is too small, will return
  294. // ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
  295. //
  296. // Can return *DecodeErr while decoding too.
  297. func (m *Message) ReadFrom(r io.Reader) (int64, error) {
  298. tBuf := m.Raw[:cap(m.Raw)]
  299. var (
  300. n int
  301. err error
  302. )
  303. if n, err = r.Read(tBuf); err != nil {
  304. return int64(n), err
  305. }
  306. m.Raw = tBuf[:n]
  307. return int64(n), m.Decode()
  308. }
  309. // ErrUnexpectedHeaderEOF means that there were not enough bytes in
  310. // m.Raw to read header.
  311. var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header")
  312. // Decode decodes m.Raw into m.
  313. func (m *Message) Decode() error {
  314. // decoding message header
  315. buf := m.Raw
  316. if len(buf) < messageHeaderSize {
  317. return ErrUnexpectedHeaderEOF
  318. }
  319. var (
  320. t = bin.Uint16(buf[0:2]) // first 2 bytes
  321. size = int(bin.Uint16(buf[2:4])) // second 2 bytes
  322. cookie = bin.Uint32(buf[4:8]) // last 4 bytes
  323. fullSize = messageHeaderSize + size // len(m.Raw)
  324. )
  325. if cookie != magicCookie {
  326. msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
  327. return newDecodeErr("message", "cookie", msg)
  328. }
  329. if len(buf) < fullSize {
  330. msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
  331. return newAttrDecodeErr("message", msg)
  332. }
  333. // saving header data
  334. m.Type.ReadValue(t)
  335. m.Length = uint32(size)
  336. copy(m.TransactionID[:], buf[8:messageHeaderSize])
  337. m.Attributes = m.Attributes[:0]
  338. var (
  339. offset = 0
  340. b = buf[messageHeaderSize:fullSize]
  341. )
  342. for offset < size {
  343. // checking that we have enough bytes to read header
  344. if len(b) < attributeHeaderSize {
  345. msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
  346. return newAttrDecodeErr("header", msg)
  347. }
  348. var (
  349. a = RawAttribute{
  350. Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes
  351. Length: bin.Uint16(b[2:4]), // second 2 bytes
  352. }
  353. aL = int(a.Length) // attribute length
  354. aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
  355. )
  356. b = b[attributeHeaderSize:] // slicing again to simplify value read
  357. offset += attributeHeaderSize
  358. if len(b) < aBuffL { // checking size
  359. msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type)
  360. return newAttrDecodeErr("value", msg)
  361. }
  362. a.Value = b[:aL]
  363. offset += aBuffL
  364. b = b[aBuffL:]
  365. m.Attributes = append(m.Attributes, a)
  366. }
  367. return nil
  368. }
  369. // Write decodes message and return error if any.
  370. //
  371. // Any error is unrecoverable, but message could be partially decoded.
  372. func (m *Message) Write(tBuf []byte) (int, error) {
  373. m.Raw = append(m.Raw[:0], tBuf...)
  374. return len(tBuf), m.Decode()
  375. }
  376. // CloneTo clones m to b securing any further m mutations.
  377. func (m *Message) CloneTo(b *Message) error {
  378. b.Raw = append(b.Raw[:0], m.Raw...)
  379. return b.Decode()
  380. }
  381. // MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
  382. type MessageClass byte
  383. // Possible values for message class in STUN Message Type.
  384. const (
  385. ClassRequest MessageClass = 0x00 // 0b00
  386. ClassIndication MessageClass = 0x01 // 0b01
  387. ClassSuccessResponse MessageClass = 0x02 // 0b10
  388. ClassErrorResponse MessageClass = 0x03 // 0b11
  389. )
  390. // Common STUN message types.
  391. var (
  392. // Binding request message type.
  393. BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals
  394. // Binding success response message type
  395. BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals
  396. // Binding error response message type.
  397. BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals
  398. )
  399. func (c MessageClass) String() string {
  400. switch c {
  401. case ClassRequest:
  402. return "request"
  403. case ClassIndication:
  404. return "indication"
  405. case ClassSuccessResponse:
  406. return "success response"
  407. case ClassErrorResponse:
  408. return "error response"
  409. default:
  410. panic("unknown message class") //nolint
  411. }
  412. }
  413. // Method is uint16 representation of 12-bit STUN method.
  414. type Method uint16
  415. // Possible methods for STUN Message.
  416. const (
  417. MethodBinding Method = 0x001
  418. MethodAllocate Method = 0x003
  419. MethodRefresh Method = 0x004
  420. MethodSend Method = 0x006
  421. MethodData Method = 0x007
  422. MethodCreatePermission Method = 0x008
  423. MethodChannelBind Method = 0x009
  424. )
  425. // Methods from RFC 6062.
  426. const (
  427. MethodConnect Method = 0x000a
  428. MethodConnectionBind Method = 0x000b
  429. MethodConnectionAttempt Method = 0x000c
  430. )
  431. func methodName() map[Method]string {
  432. return map[Method]string{
  433. MethodBinding: "Binding",
  434. MethodAllocate: "Allocate",
  435. MethodRefresh: "Refresh",
  436. MethodSend: "Send",
  437. MethodData: "Data",
  438. MethodCreatePermission: "CreatePermission",
  439. MethodChannelBind: "ChannelBind",
  440. // RFC 6062.
  441. MethodConnect: "Connect",
  442. MethodConnectionBind: "ConnectionBind",
  443. MethodConnectionAttempt: "ConnectionAttempt",
  444. }
  445. }
  446. func (m Method) String() string {
  447. s, ok := methodName()[m]
  448. if !ok {
  449. // Falling back to hex representation.
  450. s = fmt.Sprintf("0x%x", uint16(m))
  451. }
  452. return s
  453. }
  454. // MessageType is STUN Message Type Field.
  455. type MessageType struct {
  456. Method Method // e.g. binding
  457. Class MessageClass // e.g. request
  458. }
  459. // AddTo sets m type to t.
  460. func (t MessageType) AddTo(m *Message) error {
  461. m.SetType(t)
  462. return nil
  463. }
  464. // NewType returns new message type with provided method and class.
  465. func NewType(method Method, class MessageClass) MessageType {
  466. return MessageType{
  467. Method: method,
  468. Class: class,
  469. }
  470. }
  471. const (
  472. methodABits = 0xf // 0b0000000000001111
  473. methodBBits = 0x70 // 0b0000000001110000
  474. methodDBits = 0xf80 // 0b0000111110000000
  475. methodBShift = 1
  476. methodDShift = 2
  477. firstBit = 0x1
  478. secondBit = 0x2
  479. c0Bit = firstBit
  480. c1Bit = secondBit
  481. classC0Shift = 4
  482. classC1Shift = 7
  483. )
  484. // Value returns bit representation of messageType.
  485. func (t MessageType) Value() uint16 {
  486. // 0 1
  487. // 2 3 4 5 6 7 8 9 0 1 2 3 4 5
  488. // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
  489. // |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
  490. // |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
  491. // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
  492. // Figure 3: Format of STUN Message Type Field
  493. // Warning: Abandon all hope ye who enter here.
  494. // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
  495. m := uint16(t.Method)
  496. a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
  497. b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
  498. d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
  499. // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
  500. m = a + (b << methodBShift) + (d << methodDShift)
  501. // C0 is zero bit of C, C1 is first bit.
  502. // C0 = C * 0b01, C1 = (C * 0b10) >> 1
  503. // Ct = C0 << 4 + C1 << 8.
  504. // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
  505. // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
  506. // (see figure 3).
  507. c := uint16(t.Class)
  508. c0 := (c & c0Bit) << classC0Shift
  509. c1 := (c & c1Bit) << classC1Shift
  510. class := c0 + c1
  511. return m + class
  512. }
  513. // ReadValue decodes uint16 into MessageType.
  514. func (t *MessageType) ReadValue(v uint16) {
  515. // Decoding class.
  516. // We are taking first bit from v >> 4 and second from v >> 7.
  517. c0 := (v >> classC0Shift) & c0Bit
  518. c1 := (v >> classC1Shift) & c1Bit
  519. class := c0 + c1
  520. t.Class = MessageClass(class)
  521. // Decoding method.
  522. a := v & methodABits // A(M0-M3)
  523. b := (v >> methodBShift) & methodBBits // B(M4-M6)
  524. d := (v >> methodDShift) & methodDBits // D(M7-M11)
  525. m := a + b + d
  526. t.Method = Method(m)
  527. }
  528. func (t MessageType) String() string {
  529. return fmt.Sprintf("%s %s", t.Method, t.Class)
  530. }
  531. // Contains return true if message contain t attribute.
  532. func (m *Message) Contains(t AttrType) bool {
  533. for _, a := range m.Attributes {
  534. if a.Type == t {
  535. return true
  536. }
  537. }
  538. return false
  539. }
  540. type transactionIDValueSetter [TransactionIDSize]byte
  541. // NewTransactionIDSetter returns new Setter that sets message transaction id
  542. // to provided value.
  543. func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter {
  544. return transactionIDValueSetter(value)
  545. }
  546. func (t transactionIDValueSetter) AddTo(m *Message) error {
  547. m.TransactionID = t
  548. m.WriteTransactionID()
  549. return nil
  550. }