123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package stun
- import (
- "crypto/rand"
- "encoding/base64"
- "errors"
- "fmt"
- "io"
- )
// Wire-format sizes and markers shared by the encoder and the decoder.
const (
	// TransactionIDSize is length of transaction id array (in bytes).
	TransactionIDSize = 12 // 96 bit

	// messageHeaderSize is the number of bytes occupied by the message
	// header at the start of Message.Raw.
	messageHeaderSize = 20

	// attributeHeaderSize is the number of bytes occupied by the type and
	// length fields that precede every attribute value.
	attributeHeaderSize = 4

	// magicCookie is the fixed value that helps to tell STUN packets apart
	// from packets of other protocols when STUN is multiplexed with them
	// on the same port.
	//
	// The magic cookie field MUST contain the fixed value 0x2112A442 in
	// network byte order.
	//
	// Defined in "STUN Message Structure", section 6.
	magicCookie = 0x2112A442
)
- // NewTransactionID returns new random transaction ID using crypto/rand
- // as source.
- func NewTransactionID() (b [TransactionIDSize]byte) {
- readFullOrPanic(rand.Reader, b[:])
- return b
- }
- // IsMessage returns true if b looks like STUN message.
- // Useful for multiplexing. IsMessage does not guarantee
- // that decoding will be successful.
- func IsMessage(b []byte) bool {
- return len(b) >= messageHeaderSize && bin.Uint32(b[4:8]) == magicCookie
- }
- // New returns *Message with pre-allocated Raw.
- func New() *Message {
- const defaultRawCapacity = 120
- return &Message{
- Raw: make([]byte, messageHeaderSize, defaultRawCapacity),
- }
- }
- // ErrDecodeToNil occurs on Decode(data, nil) call.
- var ErrDecodeToNil = errors.New("attempt to decode to nil message")
- // Decode decodes Message from data to m, returning error if any.
- func Decode(data []byte, m *Message) error {
- if m == nil {
- return ErrDecodeToNil
- }
- m.Raw = append(m.Raw[:0], data...)
- return m.Decode()
- }
- // Message represents a single STUN packet. It uses aggressive internal
- // buffering to enable zero-allocation encoding and decoding,
- // so there are some usage constraints:
- //
- // Message, its fields, results of m.Get or any attribute a.GetFrom
- // are valid only as long as Message.Raw is not modified.
- type Message struct {
- Type MessageType // method and class; written into Raw by WriteType
- Length uint32 // size of the attribute section in bytes, i.e. len(Raw) not including header
- TransactionID [TransactionIDSize]byte // written into Raw[8:20] by WriteTransactionID
- Attributes Attributes // appended by Add and Decode; each Value set there is a sub-slice of Raw
- Raw []byte // encoded message: header followed by attributes
- }
- // MarshalBinary implements the encoding.BinaryMarshaler interface.
- func (m Message) MarshalBinary() (data []byte, err error) {
- // We can't return m.Raw, allocation is expected by implicit interface
- // contract induced by other implementations.
- b := make([]byte, len(m.Raw))
- copy(b, m.Raw)
- return b, nil
- }
- // UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
- func (m *Message) UnmarshalBinary(data []byte) error {
- // We can't retain data, copy is expected by interface contract.
- m.Raw = append(m.Raw[:0], data...)
- return m.Decode()
- }
- // GobEncode implements the gob.GobEncoder interface.
- func (m Message) GobEncode() ([]byte, error) {
- return m.MarshalBinary()
- }
- // GobDecode implements the gob.GobDecoder interface.
- func (m *Message) GobDecode(data []byte) error {
- return m.UnmarshalBinary(data)
- }
- // AddTo sets b.TransactionID to m.TransactionID.
- //
- // Implements Setter to aid in crafting responses.
- func (m *Message) AddTo(b *Message) error {
- b.TransactionID = m.TransactionID
- b.WriteTransactionID()
- return nil
- }
- // NewTransactionID sets m.TransactionID to random value from crypto/rand
- // and returns error if any.
- func (m *Message) NewTransactionID() error {
- _, err := io.ReadFull(rand.Reader, m.TransactionID[:])
- if err == nil {
- m.WriteTransactionID()
- }
- return err
- }
- func (m *Message) String() string {
- tID := base64.StdEncoding.EncodeToString(m.TransactionID[:])
- aInfo := ""
- for k, a := range m.Attributes {
- aInfo += fmt.Sprintf("attr%d=%s ", k, a.Type)
- }
- return fmt.Sprintf("%s l=%d attrs=%d id=%s, %s", m.Type, m.Length, len(m.Attributes), tID, aInfo)
- }
- // Reset resets Message, attributes and underlying buffer length.
- func (m *Message) Reset() {
- m.Raw = m.Raw[:0]
- m.Length = 0
- m.Attributes = m.Attributes[:0]
- }
- // grow ensures that internal buffer has n length.
- func (m *Message) grow(n int) {
- if len(m.Raw) >= n {
- return
- }
- if cap(m.Raw) >= n {
- m.Raw = m.Raw[:n]
- return
- }
- m.Raw = append(m.Raw, make([]byte, n-len(m.Raw))...)
- }
- // Add appends new attribute to message. Not goroutine-safe.
- //
- // Value of attribute is copied to internal buffer so
- // it is safe to reuse v.
- func (m *Message) Add(t AttrType, v []byte) {
- // Allocating buffer for TLV (type-length-value).
- // T = t, L = len(v), V = v.
- // m.Raw will look like:
- // [0:20] <- message header
- // [20:20+m.Length] <- existing message attributes
- // [20+m.Length:20+m.Length+len(v) + 4] <- allocated buffer for new TLV
- // [first:last] <- same as previous
- // [0 1|2 3|4 4 + len(v)] <- mapping for allocated buffer
- // T L V
- allocSize := attributeHeaderSize + len(v) // ~ len(TLV) = len(TL) + len(V)
- first := messageHeaderSize + int(m.Length) // first byte number
- last := first + allocSize // last byte number
- m.grow(last) // growing cap(Raw) to fit TLV
- m.Raw = m.Raw[:last] // now len(Raw) = last
- m.Length += uint32(allocSize) // rendering length change
- // Sub-slicing internal buffer to simplify encoding.
- buf := m.Raw[first:last] // slice for TLV
- value := buf[attributeHeaderSize:] // slice for V
- attr := RawAttribute{
- Type: t, // T
- Length: uint16(len(v)), // L
- Value: value, // V
- }
- // Encoding attribute TLV to allocated buffer.
- bin.PutUint16(buf[0:2], attr.Type.Value()) // T
- bin.PutUint16(buf[2:4], attr.Length) // L
- copy(value, v) // V
- // Checking that attribute value needs padding.
- if attr.Length%padding != 0 {
- // Performing padding.
- bytesToAdd := nearestPaddedValueLength(len(v)) - len(v)
- last += bytesToAdd
- m.grow(last)
- // setting all padding bytes to zero
- // to prevent data leak from previous
- // data in next bytesToAdd bytes
- buf = m.Raw[last-bytesToAdd : last]
- for i := range buf {
- buf[i] = 0
- }
- m.Raw = m.Raw[:last] // increasing buffer length
- m.Length += uint32(bytesToAdd) // rendering length change
- }
- m.Attributes = append(m.Attributes, attr)
- m.WriteLength()
- }
- func attrSliceEqual(a, b Attributes) bool {
- for _, attr := range a {
- found := false
- for _, attrB := range b {
- if attrB.Type != attr.Type {
- continue
- }
- if attrB.Equal(attr) {
- found = true
- break
- }
- }
- if !found {
- return false
- }
- }
- return true
- }
- func attrEqual(a, b Attributes) bool {
- if a == nil && b == nil {
- return true
- }
- if a == nil || b == nil {
- return false
- }
- if len(a) != len(b) {
- return false
- }
- if !attrSliceEqual(a, b) {
- return false
- }
- if !attrSliceEqual(b, a) {
- return false
- }
- return true
- }
- // Equal returns true if Message b equals to m.
- // Ignores m.Raw.
- func (m *Message) Equal(b *Message) bool {
- if m == nil && b == nil {
- return true
- }
- if m == nil || b == nil {
- return false
- }
- if m.Type != b.Type {
- return false
- }
- if m.TransactionID != b.TransactionID {
- return false
- }
- if m.Length != b.Length {
- return false
- }
- if !attrEqual(m.Attributes, b.Attributes) {
- return false
- }
- return true
- }
- // WriteLength writes m.Length to m.Raw.
- func (m *Message) WriteLength() {
- m.grow(4)
- bin.PutUint16(m.Raw[2:4], uint16(m.Length))
- }
- // WriteHeader writes header to underlying buffer. Not goroutine-safe.
- func (m *Message) WriteHeader() {
- m.grow(messageHeaderSize)
- _ = m.Raw[:messageHeaderSize] // early bounds check to guarantee safety of writes below
- m.WriteType()
- m.WriteLength()
- bin.PutUint32(m.Raw[4:8], magicCookie) // magic cookie
- copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
- }
- // WriteTransactionID writes m.TransactionID to m.Raw.
- func (m *Message) WriteTransactionID() {
- copy(m.Raw[8:messageHeaderSize], m.TransactionID[:]) // transaction ID
- }
- // WriteAttributes encodes all m.Attributes to m.
- func (m *Message) WriteAttributes() {
- attributes := m.Attributes
- m.Attributes = attributes[:0]
- for _, a := range attributes {
- m.Add(a.Type, a.Value)
- }
- m.Attributes = attributes
- }
- // WriteType writes m.Type to m.Raw.
- func (m *Message) WriteType() {
- m.grow(2)
- bin.PutUint16(m.Raw[0:2], m.Type.Value()) // message type
- }
- // SetType sets m.Type and writes it to m.Raw.
- func (m *Message) SetType(t MessageType) {
- m.Type = t
- m.WriteType()
- }
- // Encode re-encodes message into m.Raw.
- func (m *Message) Encode() {
- m.Raw = m.Raw[:0]
- m.WriteHeader()
- m.Length = 0
- m.WriteAttributes()
- }
- // WriteTo implements WriterTo via calling Write(m.Raw) on w and returning
- // call result.
- func (m *Message) WriteTo(w io.Writer) (int64, error) {
- n, err := w.Write(m.Raw)
- return int64(n), err
- }
- // ReadFrom implements ReaderFrom. Reads message from r into m.Raw,
- // Decodes it and return error if any. If m.Raw is too small, will return
- // ErrUnexpectedEOF, ErrUnexpectedHeaderEOF or *DecodeErr.
- //
- // Can return *DecodeErr while decoding too.
- func (m *Message) ReadFrom(r io.Reader) (int64, error) {
- tBuf := m.Raw[:cap(m.Raw)]
- var (
- n int
- err error
- )
- if n, err = r.Read(tBuf); err != nil {
- return int64(n), err
- }
- m.Raw = tBuf[:n]
- return int64(n), m.Decode()
- }
- // ErrUnexpectedHeaderEOF means that there were not enough bytes in
- // m.Raw to read header.
- var ErrUnexpectedHeaderEOF = errors.New("unexpected EOF: not enough bytes to read header")
- // Decode decodes m.Raw into m.
- func (m *Message) Decode() error {
- // decoding message header
- buf := m.Raw
- if len(buf) < messageHeaderSize {
- return ErrUnexpectedHeaderEOF
- }
- var (
- t = bin.Uint16(buf[0:2]) // first 2 bytes
- size = int(bin.Uint16(buf[2:4])) // second 2 bytes
- cookie = bin.Uint32(buf[4:8]) // last 4 bytes
- fullSize = messageHeaderSize + size // len(m.Raw)
- )
- if cookie != magicCookie {
- msg := fmt.Sprintf("%x is invalid magic cookie (should be %x)", cookie, magicCookie)
- return newDecodeErr("message", "cookie", msg)
- }
- if len(buf) < fullSize {
- msg := fmt.Sprintf("buffer length %d is less than %d (expected message size)", len(buf), fullSize)
- return newAttrDecodeErr("message", msg)
- }
- // saving header data
- m.Type.ReadValue(t)
- m.Length = uint32(size)
- copy(m.TransactionID[:], buf[8:messageHeaderSize])
- m.Attributes = m.Attributes[:0]
- var (
- offset = 0
- b = buf[messageHeaderSize:fullSize]
- )
- for offset < size {
- // checking that we have enough bytes to read header
- if len(b) < attributeHeaderSize {
- msg := fmt.Sprintf("buffer length %d is less than %d (expected header size)", len(b), attributeHeaderSize)
- return newAttrDecodeErr("header", msg)
- }
- var (
- a = RawAttribute{
- Type: compatAttrType(bin.Uint16(b[0:2])), // first 2 bytes
- Length: bin.Uint16(b[2:4]), // second 2 bytes
- }
- aL = int(a.Length) // attribute length
- aBuffL = nearestPaddedValueLength(aL) // expected buffer length (with padding)
- )
- b = b[attributeHeaderSize:] // slicing again to simplify value read
- offset += attributeHeaderSize
- if len(b) < aBuffL { // checking size
- msg := fmt.Sprintf("buffer length %d is less than %d (expected value size for %s)", len(b), aBuffL, a.Type)
- return newAttrDecodeErr("value", msg)
- }
- a.Value = b[:aL]
- offset += aBuffL
- b = b[aBuffL:]
- m.Attributes = append(m.Attributes, a)
- }
- return nil
- }
- // Write decodes message and return error if any.
- //
- // Any error is unrecoverable, but message could be partially decoded.
- func (m *Message) Write(tBuf []byte) (int, error) {
- m.Raw = append(m.Raw[:0], tBuf...)
- return len(tBuf), m.Decode()
- }
- // CloneTo clones m to b securing any further m mutations.
- func (m *Message) CloneTo(b *Message) error {
- b.Raw = append(b.Raw[:0], m.Raw...)
- return b.Decode()
- }
// MessageClass is 8-bit representation of 2-bit class of STUN Message Class.
type MessageClass byte

// Possible values for message class in STUN Message Type.
const (
	// ClassRequest is the request class (0b00).
	ClassRequest MessageClass = 0x00
	// ClassIndication is the indication class (0b01).
	ClassIndication MessageClass = 0x01
	// ClassSuccessResponse is the success response class (0b10).
	ClassSuccessResponse MessageClass = 0x02
	// ClassErrorResponse is the error response class (0b11).
	ClassErrorResponse MessageClass = 0x03
)
- // Common STUN message types.
- var (
- // Binding request message type.
- BindingRequest = NewType(MethodBinding, ClassRequest) //nolint:gochecknoglobals
- // Binding success response message type
- BindingSuccess = NewType(MethodBinding, ClassSuccessResponse) //nolint:gochecknoglobals
- // Binding error response message type.
- BindingError = NewType(MethodBinding, ClassErrorResponse) //nolint:gochecknoglobals
- )
- func (c MessageClass) String() string {
- switch c {
- case ClassRequest:
- return "request"
- case ClassIndication:
- return "indication"
- case ClassSuccessResponse:
- return "success response"
- case ClassErrorResponse:
- return "error response"
- default:
- panic("unknown message class") //nolint
- }
- }
// Method is uint16 representation of 12-bit STUN method.
type Method uint16

// Possible methods for STUN Message.
const (
	MethodBinding          Method = 0x001
	MethodAllocate         Method = 0x003
	MethodRefresh          Method = 0x004
	MethodSend             Method = 0x006
	MethodData             Method = 0x007
	MethodCreatePermission Method = 0x008
	MethodChannelBind      Method = 0x009
)

// Methods from RFC 6062.
const (
	MethodConnect           Method = 0x000a
	MethodConnectionBind    Method = 0x000b
	MethodConnectionAttempt Method = 0x000c
)

// methodName returns a newly built table that maps every declared method
// to its display name.
func methodName() map[Method]string {
	names := make(map[Method]string, 10)

	names[MethodBinding] = "Binding"
	names[MethodAllocate] = "Allocate"
	names[MethodRefresh] = "Refresh"
	names[MethodSend] = "Send"
	names[MethodData] = "Data"
	names[MethodCreatePermission] = "CreatePermission"
	names[MethodChannelBind] = "ChannelBind"

	// RFC 6062.
	names[MethodConnect] = "Connect"
	names[MethodConnectionBind] = "ConnectionBind"
	names[MethodConnectionAttempt] = "ConnectionAttempt"

	return names
}

// String returns the display name of a declared method, or the value in
// 0x-prefixed lower-case hex for any other method.
func (m Method) String() string {
	if name, known := methodName()[m]; known {
		return name
	}

	// Falling back to hex representation.
	return fmt.Sprintf("0x%x", uint16(m))
}
- // MessageType is the STUN Message Type field: a Method combined with a MessageClass.
- type MessageType struct {
- Method Method // e.g. binding; Value encodes its low 12 bits
- Class MessageClass // e.g. request; Value encodes its low 2 bits
- }
- // AddTo sets m type to t.
- func (t MessageType) AddTo(m *Message) error {
- m.SetType(t)
- return nil
- }
- // NewType returns new message type with provided method and class.
- func NewType(method Method, class MessageClass) MessageType {
- return MessageType{
- Method: method,
- Class: class,
- }
- }
// Masks and shifts used by MessageType.Value and MessageType.ReadValue to
// interleave the method bits with the two class bits.
const (
	// Method bit groups, as they appear in the plain method value.
	methodABits = 0xf   // 0b0000000000001111
	methodBBits = 0x70  // 0b0000000001110000
	methodDBits = 0xf80 // 0b0000111110000000

	// Left shifts that open the "holes" for the class bits.
	methodBShift = 1
	methodDShift = 2

	// Class bits, as they appear in the plain class value.
	firstBit  = 0x1
	secondBit = 0x2
	c0Bit     = firstBit
	c1Bit     = secondBit

	// Left shifts that move the class bits to their wire positions.
	classC0Shift = 4
	classC1Shift = 7
)
- // Value returns bit representation of messageType.
- func (t MessageType) Value() uint16 {
- // 0 1
- // 2 3 4 5 6 7 8 9 0 1 2 3 4 5
- // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
- // |M |M |M|M|M|C|M|M|M|C|M|M|M|M|
- // |11|10|9|8|7|1|6|5|4|0|3|2|1|0|
- // +--+--+-+-+-+-+-+-+-+-+-+-+-+-+
- // Figure 3: Format of STUN Message Type Field
- // Warning: Abandon all hope ye who enter here.
- // Splitting M into A(M0-M3), B(M4-M6), D(M7-M11).
- m := uint16(t.Method)
- a := m & methodABits // A = M * 0b0000000000001111 (right 4 bits)
- b := m & methodBBits // B = M * 0b0000000001110000 (3 bits after A)
- d := m & methodDBits // D = M * 0b0000111110000000 (5 bits after B)
- // Shifting to add "holes" for C0 (at 4 bit) and C1 (8 bit).
- m = a + (b << methodBShift) + (d << methodDShift)
- // C0 is zero bit of C, C1 is first bit.
- // C0 = C * 0b01, C1 = (C * 0b10) >> 1
- // Ct = C0 << 4 + C1 << 8.
- // Optimizations: "((C * 0b10) >> 1) << 8" as "(C * 0b10) << 7"
- // We need C0 shifted by 4, and C1 by 8 to fit "11" and "7" positions
- // (see figure 3).
- c := uint16(t.Class)
- c0 := (c & c0Bit) << classC0Shift
- c1 := (c & c1Bit) << classC1Shift
- class := c0 + c1
- return m + class
- }
- // ReadValue decodes uint16 into MessageType.
- func (t *MessageType) ReadValue(v uint16) {
- // Decoding class.
- // We are taking first bit from v >> 4 and second from v >> 7.
- c0 := (v >> classC0Shift) & c0Bit
- c1 := (v >> classC1Shift) & c1Bit
- class := c0 + c1
- t.Class = MessageClass(class)
- // Decoding method.
- a := v & methodABits // A(M0-M3)
- b := (v >> methodBShift) & methodBBits // B(M4-M6)
- d := (v >> methodDShift) & methodDBits // D(M7-M11)
- m := a + b + d
- t.Method = Method(m)
- }
- func (t MessageType) String() string {
- return fmt.Sprintf("%s %s", t.Method, t.Class)
- }
- // Contains return true if message contain t attribute.
- func (m *Message) Contains(t AttrType) bool {
- for _, a := range m.Attributes {
- if a.Type == t {
- return true
- }
- }
- return false
- }
- type transactionIDValueSetter [TransactionIDSize]byte
- // NewTransactionIDSetter returns new Setter that sets message transaction id
- // to provided value.
- func NewTransactionIDSetter(value [TransactionIDSize]byte) Setter {
- return transactionIDValueSetter(value)
- }
- func (t transactionIDValueSetter) AddTo(m *Message) error {
- m.TransactionID = t
- m.WriteTransactionID()
- return nil
- }
|