cipher.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. package wsutil
  2. import (
  3. "io"
  4. "github.com/gobwas/pool/pbytes"
  5. "github.com/gobwas/ws"
  6. )
  // CipherReader implements io.Reader by applying an XOR cipher with a
  // 4-byte mask to the bytes read from an underlying source reader.
  // It can be used to unmask a WebSocket frame payload on the fly.
  type CipherReader struct {
  	r    io.Reader // source the bytes are read from
  	mask [4]byte   // XOR key applied by Read
  	pos  int       // bytes ciphered since construction or Reset; passed to ws.Cipher as its offset
  }
  15. // NewCipherReader creates xor-cipher reader from r with given mask.
  16. func NewCipherReader(r io.Reader, mask [4]byte) *CipherReader {
  17. return &CipherReader{r, mask, 0}
  18. }
  19. // Reset resets CipherReader to read from r with given mask.
  20. func (c *CipherReader) Reset(r io.Reader, mask [4]byte) {
  21. c.r = r
  22. c.mask = mask
  23. c.pos = 0
  24. }
  25. // Read implements io.Reader interface. It applies mask given during
  26. // initialization to every read byte.
  27. func (c *CipherReader) Read(p []byte) (n int, err error) {
  28. n, err = c.r.Read(p)
  29. ws.Cipher(p[:n], c.mask, c.pos)
  30. c.pos += n
  31. return
  32. }
  // CipherWriter implements io.Writer by applying an XOR cipher with a
  // 4-byte mask to the bytes written to a destination writer. It does not
  // modify the caller's original bytes.
  type CipherWriter struct {
  	w    io.Writer // destination that receives the ciphered bytes
  	mask [4]byte   // XOR key applied by Write
  	pos  int       // bytes accepted by w since construction or Reset; passed to ws.Cipher as its offset
  }
  40. // NewCipherWriter creates xor-cipher writer to w with given mask.
  41. func NewCipherWriter(w io.Writer, mask [4]byte) *CipherWriter {
  42. return &CipherWriter{w, mask, 0}
  43. }
  44. // Reset reset CipherWriter to write to w with given mask.
  45. func (c *CipherWriter) Reset(w io.Writer, mask [4]byte) {
  46. c.w = w
  47. c.mask = mask
  48. c.pos = 0
  49. }
  50. // Write implements io.Writer interface. It applies masking during
  51. // initialization to every sent byte. It does not modify original slice.
  52. func (c *CipherWriter) Write(p []byte) (n int, err error) {
  53. cp := pbytes.GetLen(len(p))
  54. defer pbytes.Put(cp)
  55. copy(cp, p)
  56. ws.Cipher(cp, c.mask, c.pos)
  57. n, err = c.w.Write(cp)
  58. c.pos += n
  59. return
  60. }