variance_avx2.c

/*
 *  Copyright (c) 2012 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS. All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <immintrin.h>  // AVX2

#include "./vpx_dsp_rtcd.h"

/* clang-format off */
DECLARE_ALIGNED(32, static const uint8_t, bilinear_filters_avx2[512]) = {
  16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0,
  16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0, 16, 0,
  14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2,
  14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2,
  12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4,
  12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4,
  10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6,
  10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6,
  8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
  8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
  6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10,
  6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10, 6, 10,
  4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12,
  4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12, 4, 12,
  2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14,
  2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14, 2, 14,
};
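
// Each of the eight bilinear filters above occupies 32 bytes: the two taps
// (16 - offset, offset) are repeated as byte pairs across a full 256-bit
// register so they can be fed straight to _mm256_maddubs_epi16. A filter is
// selected below with bilinear_filters_avx2 + (offset << 5).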
DECLARE_ALIGNED(32, static const int8_t, adjacent_sub_avx2[32]) = {
  1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1,
  1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1
};
/* clang-format on */
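
// Accumulate the sum and sum of squared differences for one pair of 32-byte
// src/ref vectors. The bytes are interleaved so that _mm256_maddubs_epi16
// with the (+1, -1) pattern above produces 16-bit src - ref differences in a
// single step; _mm256_madd_epi16 then squares and pairwise-adds them into
// 32-bit lanes.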
static INLINE void variance_kernel_avx2(const __m256i src, const __m256i ref,
                                        __m256i *const sse,
                                        __m256i *const sum) {
  const __m256i adj_sub = _mm256_load_si256((__m256i const *)adjacent_sub_avx2);

  // unpack into pairs of source and reference values
  const __m256i src_ref0 = _mm256_unpacklo_epi8(src, ref);
  const __m256i src_ref1 = _mm256_unpackhi_epi8(src, ref);

  // subtract adjacent elements using src*1 + ref*-1
  const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
  const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
  const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
  const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);

  // add to the running totals
  *sum = _mm256_add_epi16(*sum, _mm256_add_epi16(diff0, diff1));
  *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(madd0, madd1));
}
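
// Horizontally reduce the 32-bit SSE lanes and the already-widened sum to
// scalars: add the two 128-bit halves, then interleave sse and sum so one
// final pair of adds finishes both reductions at once.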
static INLINE void variance_final_from_32bit_sum_avx2(__m256i vsse,
                                                      __m128i vsum,
                                                      unsigned int *const sse,
                                                      int *const sum) {
  // extract the low lane and add it to the high lane
  const __m128i sse_reg_128 = _mm_add_epi32(_mm256_castsi256_si128(vsse),
                                            _mm256_extractf128_si256(vsse, 1));

  // unpack sse and sum registers and add
  const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, vsum);
  const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, vsum);
  const __m128i sse_sum = _mm_add_epi32(sse_sum_lo, sse_sum_hi);

  // perform the final summation and extract the results
  const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
  *((int *)sse) = _mm_cvtsi128_si32(res);
  *((int *)sum) = _mm_extract_epi32(res, 1);
}

static INLINE void variance_final_from_16bit_sum_avx2(__m256i vsse,
                                                      __m256i vsum,
                                                      unsigned int *const sse,
                                                      int *const sum) {
  // extract the low lane and add it to the high lane
  const __m128i sum_reg_128 = _mm_add_epi16(_mm256_castsi256_si128(vsum),
                                            _mm256_extractf128_si256(vsum, 1));
  const __m128i sum_reg_64 =
      _mm_add_epi16(sum_reg_128, _mm_srli_si128(sum_reg_128, 8));
  const __m128i sum_int32 = _mm_cvtepi16_epi32(sum_reg_64);
  variance_final_from_32bit_sum_avx2(vsse, sum_int32, sse, sum);
}

static INLINE __m256i sum_to_32bit_avx2(const __m256i sum) {
  const __m256i sum_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum));
  const __m256i sum_hi =
      _mm256_cvtepi16_epi32(_mm256_extractf128_si256(sum, 1));
  return _mm256_add_epi32(sum_lo, sum_hi);
}
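
// Load two 16-pixel rows of src and ref into the low and high 128-bit lanes
// of a single 256-bit register, so each kernel call covers two rows.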
static INLINE void variance16_kernel_avx2(
    const uint8_t *const src, const int src_stride, const uint8_t *const ref,
    const int ref_stride, __m256i *const sse, __m256i *const sum) {
  const __m128i s0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride));
  const __m128i s1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride));
  const __m128i r0 = _mm_loadu_si128((__m128i const *)(ref + 0 * ref_stride));
  const __m128i r1 = _mm_loadu_si128((__m128i const *)(ref + 1 * ref_stride));
  const __m256i s = _mm256_inserti128_si256(_mm256_castsi128_si256(s0), s1, 1);
  const __m256i r = _mm256_inserti128_si256(_mm256_castsi128_si256(r0), r1, 1);
  variance_kernel_avx2(s, r, sse, sum);
}

static INLINE void variance32_kernel_avx2(const uint8_t *const src,
                                          const uint8_t *const ref,
                                          __m256i *const sse,
                                          __m256i *const sum) {
  const __m256i s = _mm256_loadu_si256((__m256i const *)(src));
  const __m256i r = _mm256_loadu_si256((__m256i const *)(ref));
  variance_kernel_avx2(s, r, sse, sum);
}

static INLINE void variance16_avx2(const uint8_t *src, const int src_stride,
                                   const uint8_t *ref, const int ref_stride,
                                   const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
  int i;
  *vsum = _mm256_setzero_si256();
  *vsse = _mm256_setzero_si256();

  for (i = 0; i < h; i += 2) {
    variance16_kernel_avx2(src, src_stride, ref, ref_stride, vsse, vsum);
    src += 2 * src_stride;
    ref += 2 * ref_stride;
  }
}

static INLINE void variance32_avx2(const uint8_t *src, const int src_stride,
                                   const uint8_t *ref, const int ref_stride,
                                   const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
  int i;
  *vsum = _mm256_setzero_si256();
  *vsse = _mm256_setzero_si256();

  for (i = 0; i < h; i++) {
    variance32_kernel_avx2(src, ref, vsse, vsum);
    src += src_stride;
    ref += ref_stride;
  }
}
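
// Note: unlike the helpers above, variance64_avx2() clears only *vsum and
// accumulates into whatever *vsse already holds; callers zero *vsse
// themselves, which lets vpx_variance64x64_avx2() chain two calls.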
static INLINE void variance64_avx2(const uint8_t *src, const int src_stride,
                                   const uint8_t *ref, const int ref_stride,
                                   const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
  int i;
  *vsum = _mm256_setzero_si256();

  for (i = 0; i < h; i++) {
    variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
    variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
    src += src_stride;
    ref += ref_stride;
  }
}

void vpx_get16x16var_avx2(const uint8_t *src_ptr, int src_stride,
                          const uint8_t *ref_ptr, int ref_stride,
                          unsigned int *sse, int *sum) {
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 16, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, sum);
}
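
// FILTER_SRC applies a 2-tap bilinear filter to the unpacked byte pairs:
// a * (16 - offset) + b * offset, rounded with +8 and shifted right by 4,
// i.e. the usual (a * F0 + b * F1 + 8) >> 4 interpolation.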
#define FILTER_SRC(filter)                               \
  /* filter the source */                                \
  exp_src_lo = _mm256_maddubs_epi16(exp_src_lo, filter); \
  exp_src_hi = _mm256_maddubs_epi16(exp_src_hi, filter); \
                                                         \
  /* add 8 to source */                                  \
  exp_src_lo = _mm256_add_epi16(exp_src_lo, pw8);        \
  exp_src_hi = _mm256_add_epi16(exp_src_hi, pw8);        \
                                                         \
  /* divide source by 16 */                              \
  exp_src_lo = _mm256_srai_epi16(exp_src_lo, 4);         \
  exp_src_hi = _mm256_srai_epi16(exp_src_hi, 4);

#define CALC_SUM_SSE_INSIDE_LOOP                          \
  /* expand each byte to 2 bytes */                       \
  exp_dst_lo = _mm256_unpacklo_epi8(dst_reg, zero_reg);   \
  exp_dst_hi = _mm256_unpackhi_epi8(dst_reg, zero_reg);   \
  /* source - dest */                                     \
  exp_src_lo = _mm256_sub_epi16(exp_src_lo, exp_dst_lo);  \
  exp_src_hi = _mm256_sub_epi16(exp_src_hi, exp_dst_hi);  \
  /* calculate sum */                                     \
  *sum_reg = _mm256_add_epi16(*sum_reg, exp_src_lo);      \
  exp_src_lo = _mm256_madd_epi16(exp_src_lo, exp_src_lo); \
  *sum_reg = _mm256_add_epi16(*sum_reg, exp_src_hi);      \
  exp_src_hi = _mm256_madd_epi16(exp_src_hi, exp_src_hi); \
  /* calculate sse */                                     \
  *sse_reg = _mm256_add_epi32(*sse_reg, exp_src_lo);      \
  *sse_reg = _mm256_add_epi32(*sse_reg, exp_src_hi);
// final calculation to sum and sse
#define CALC_SUM_AND_SSE                                                   \
  res_cmp = _mm256_cmpgt_epi16(zero_reg, sum_reg);                         \
  sse_reg_hi = _mm256_srli_si256(sse_reg, 8);                              \
  sum_reg_lo = _mm256_unpacklo_epi16(sum_reg, res_cmp);                    \
  sum_reg_hi = _mm256_unpackhi_epi16(sum_reg, res_cmp);                    \
  sse_reg = _mm256_add_epi32(sse_reg, sse_reg_hi);                         \
  sum_reg = _mm256_add_epi32(sum_reg_lo, sum_reg_hi);                      \
                                                                           \
  sse_reg_hi = _mm256_srli_si256(sse_reg, 4);                              \
  sum_reg_hi = _mm256_srli_si256(sum_reg, 8);                              \
                                                                           \
  sse_reg = _mm256_add_epi32(sse_reg, sse_reg_hi);                         \
  sum_reg = _mm256_add_epi32(sum_reg, sum_reg_hi);                         \
  *((int *)sse) = _mm_cvtsi128_si32(_mm256_castsi256_si128(sse_reg)) +     \
                  _mm_cvtsi128_si32(_mm256_extractf128_si256(sse_reg, 1)); \
  sum_reg_hi = _mm256_srli_si256(sum_reg, 4);                              \
  sum_reg = _mm256_add_epi32(sum_reg, sum_reg_hi);                         \
  sum = _mm_cvtsi128_si32(_mm256_castsi256_si128(sum_reg)) +               \
        _mm_cvtsi128_si32(_mm256_extractf128_si256(sum_reg, 1));
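
// The spv32_* helpers below each handle one combination of horizontal and
// vertical sub-pel offsets for a 32-pixel-wide block: _x0/_y0 means no
// filtering, _x4/_y4 is the half-pel case handled with a plain pavgb
// average, and _xb/_yb applies the bilinear filter selected by the offset
// (0..7). When do_sec is set, the filtered prediction is additionally
// averaged with second_pred.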
static INLINE void spv32_x0_y0(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg) {
  const __m256i zero_reg = _mm256_setzero_si256();
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_reg = _mm256_loadu_si256((__m256i const *)src);
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i avg_reg = _mm256_avg_epu8(src_reg, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    } else {
      exp_src_lo = _mm256_unpacklo_epi8(src_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(src_reg, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    src += src_stride;
    dst += dst_stride;
  }
}
// (x == 0, y == 4) or (x == 4, y == 0). sstep determines the direction.
static INLINE void spv32_half_zero(const uint8_t *src, int src_stride,
                                   const uint8_t *dst, int dst_stride,
                                   const uint8_t *second_pred,
                                   int second_stride, int do_sec, int height,
                                   __m256i *sum_reg, __m256i *sse_reg,
                                   int sstep) {
  const __m256i zero_reg = _mm256_setzero_si256();
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + sstep));
    const __m256i src_avg = _mm256_avg_epu8(src_0, src_1);
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i avg_reg = _mm256_avg_epu8(src_avg, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    } else {
      exp_src_lo = _mm256_unpacklo_epi8(src_avg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(src_avg, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    src += src_stride;
    dst += dst_stride;
  }
}

static INLINE void spv32_x0_y4(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg) {
  spv32_half_zero(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, sum_reg, sse_reg, src_stride);
}

static INLINE void spv32_x4_y0(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg) {
  spv32_half_zero(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, sum_reg, sse_reg, 1);
}
static INLINE void spv32_x4_y4(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg) {
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i src_a = _mm256_loadu_si256((__m256i const *)src);
  const __m256i src_b = _mm256_loadu_si256((__m256i const *)(src + 1));
  __m256i prev_src_avg = _mm256_avg_epu8(src_a, src_b);
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  src += src_stride;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)(src));
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + 1));
    const __m256i src_avg = _mm256_avg_epu8(src_0, src_1);
    const __m256i current_avg = _mm256_avg_epu8(prev_src_avg, src_avg);
    // save the current row's horizontal average for the next iteration
    prev_src_avg = src_avg;
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i avg_reg = _mm256_avg_epu8(current_avg, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    } else {
      exp_src_lo = _mm256_unpacklo_epi8(current_avg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(current_avg, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    dst += dst_stride;
    src += src_stride;
  }
}
// (x == 0, y == bil) or (x == bil, y == 0). sstep determines the direction.
static INLINE void spv32_bilin_zero(const uint8_t *src, int src_stride,
                                    const uint8_t *dst, int dst_stride,
                                    const uint8_t *second_pred,
                                    int second_stride, int do_sec, int height,
                                    __m256i *sum_reg, __m256i *sse_reg,
                                    int offset, int sstep) {
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i pw8 = _mm256_set1_epi16(8);
  const __m256i filter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (offset << 5)));
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + sstep));
    exp_src_lo = _mm256_unpacklo_epi8(src_0, src_1);
    exp_src_hi = _mm256_unpackhi_epi8(src_0, src_1);
    FILTER_SRC(filter)
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i exp_src = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
      const __m256i avg_reg = _mm256_avg_epu8(exp_src, sec_reg);
      second_pred += second_stride;
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    src += src_stride;
    dst += dst_stride;
  }
}
static INLINE void spv32_x0_yb(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int y_offset) {
  spv32_bilin_zero(src, src_stride, dst, dst_stride, second_pred, second_stride,
                   do_sec, height, sum_reg, sse_reg, y_offset, src_stride);
}

static INLINE void spv32_xb_y0(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int x_offset) {
  spv32_bilin_zero(src, src_stride, dst, dst_stride, second_pred, second_stride,
                   do_sec, height, sum_reg, sse_reg, x_offset, 1);
}
static INLINE void spv32_x4_yb(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int y_offset) {
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i pw8 = _mm256_set1_epi16(8);
  const __m256i filter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (y_offset << 5)));
  const __m256i src_a = _mm256_loadu_si256((__m256i const *)src);
  const __m256i src_b = _mm256_loadu_si256((__m256i const *)(src + 1));
  __m256i prev_src_avg = _mm256_avg_epu8(src_a, src_b);
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  src += src_stride;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + 1));
    const __m256i src_avg = _mm256_avg_epu8(src_0, src_1);
    exp_src_lo = _mm256_unpacklo_epi8(prev_src_avg, src_avg);
    exp_src_hi = _mm256_unpackhi_epi8(prev_src_avg, src_avg);
    prev_src_avg = src_avg;
    FILTER_SRC(filter)
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i exp_src_avg = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
      const __m256i avg_reg = _mm256_avg_epu8(exp_src_avg, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    }
    CALC_SUM_SSE_INSIDE_LOOP
    dst += dst_stride;
    src += src_stride;
  }
}
static INLINE void spv32_xb_y4(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int x_offset) {
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i pw8 = _mm256_set1_epi16(8);
  const __m256i filter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (x_offset << 5)));
  const __m256i src_a = _mm256_loadu_si256((__m256i const *)src);
  const __m256i src_b = _mm256_loadu_si256((__m256i const *)(src + 1));
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  __m256i src_reg, src_pack;
  int i;
  exp_src_lo = _mm256_unpacklo_epi8(src_a, src_b);
  exp_src_hi = _mm256_unpackhi_epi8(src_a, src_b);
  FILTER_SRC(filter)
  // pack the filtered 16-bit values back to 8 bits within each lane
  src_pack = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
  src += src_stride;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + 1));
    exp_src_lo = _mm256_unpacklo_epi8(src_0, src_1);
    exp_src_hi = _mm256_unpackhi_epi8(src_0, src_1);
    FILTER_SRC(filter)
    src_reg = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
    // average the previous row's packed result with the current one
    src_pack = _mm256_avg_epu8(src_pack, src_reg);
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i avg_pack = _mm256_avg_epu8(src_pack, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_pack, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_pack, zero_reg);
      second_pred += second_stride;
    } else {
      exp_src_lo = _mm256_unpacklo_epi8(src_pack, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(src_pack, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    src_pack = src_reg;
    dst += dst_stride;
    src += src_stride;
  }
}
static INLINE void spv32_xb_yb(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int x_offset, int y_offset) {
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i pw8 = _mm256_set1_epi16(8);
  const __m256i xfilter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (x_offset << 5)));
  const __m256i yfilter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (y_offset << 5)));
  const __m256i src_a = _mm256_loadu_si256((__m256i const *)src);
  const __m256i src_b = _mm256_loadu_si256((__m256i const *)(src + 1));
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  __m256i prev_src_pack, src_pack;
  int i;
  exp_src_lo = _mm256_unpacklo_epi8(src_a, src_b);
  exp_src_hi = _mm256_unpackhi_epi8(src_a, src_b);
  FILTER_SRC(xfilter)
  // pack the filtered 16-bit values back to 8 bits within each lane
  prev_src_pack = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
  src += src_stride;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + 1));
    exp_src_lo = _mm256_unpacklo_epi8(src_0, src_1);
    exp_src_hi = _mm256_unpackhi_epi8(src_0, src_1);
    FILTER_SRC(xfilter)
    src_pack = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
    // interleave the previous row's packed result with the current one
    exp_src_lo = _mm256_unpacklo_epi8(prev_src_pack, src_pack);
    exp_src_hi = _mm256_unpackhi_epi8(prev_src_pack, src_pack);
    FILTER_SRC(yfilter)
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i exp_src = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
      const __m256i avg_reg = _mm256_avg_epu8(exp_src, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    }
    prev_src_pack = src_pack;
    CALC_SUM_SSE_INSIDE_LOOP
    dst += dst_stride;
    src += src_stride;
  }
}
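
// Dispatch to the spv32_* kernel matching the (x_offset, y_offset) pair,
// then reduce the vector accumulators with CALC_SUM_AND_SSE. Writes the sum
// of squared differences to *sse and returns the signed sum of differences.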
static INLINE int sub_pix_var32xh(const uint8_t *src, int src_stride,
                                  int x_offset, int y_offset,
                                  const uint8_t *dst, int dst_stride,
                                  const uint8_t *second_pred, int second_stride,
                                  int do_sec, int height, unsigned int *sse) {
  const __m256i zero_reg = _mm256_setzero_si256();
  __m256i sum_reg = _mm256_setzero_si256();
  __m256i sse_reg = _mm256_setzero_si256();
  __m256i sse_reg_hi, res_cmp, sum_reg_lo, sum_reg_hi;
  int sum;
  // x_offset = 0 and y_offset = 0
  if (x_offset == 0) {
    if (y_offset == 0) {
      spv32_x0_y0(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg);
      // x_offset = 0 and y_offset = 4
    } else if (y_offset == 4) {
      spv32_x0_y4(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg);
      // x_offset = 0 and y_offset = bilin interpolation
    } else {
      spv32_x0_yb(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, y_offset);
    }
    // x_offset = 4 and y_offset = 0
  } else if (x_offset == 4) {
    if (y_offset == 0) {
      spv32_x4_y0(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg);
      // x_offset = 4 and y_offset = 4
    } else if (y_offset == 4) {
      spv32_x4_y4(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg);
      // x_offset = 4 and y_offset = bilin interpolation
    } else {
      spv32_x4_yb(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, y_offset);
    }
    // x_offset = bilin interpolation and y_offset = 0
  } else {
    if (y_offset == 0) {
      spv32_xb_y0(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, x_offset);
      // x_offset = bilin interpolation and y_offset = 4
    } else if (y_offset == 4) {
      spv32_xb_y4(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, x_offset);
      // x_offset = bilin interpolation and y_offset = bilin interpolation
    } else {
      spv32_xb_yb(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, x_offset, y_offset);
    }
  }
  CALC_SUM_AND_SSE
  return sum;
}
static unsigned int sub_pixel_variance32xh_avx2(
    const uint8_t *src, int src_stride, int x_offset, int y_offset,
    const uint8_t *dst, int dst_stride, int height, unsigned int *sse) {
  return sub_pix_var32xh(src, src_stride, x_offset, y_offset, dst, dst_stride,
                         NULL, 0, 0, height, sse);
}

static unsigned int sub_pixel_avg_variance32xh_avx2(
    const uint8_t *src, int src_stride, int x_offset, int y_offset,
    const uint8_t *dst, int dst_stride, const uint8_t *second_pred,
    int second_stride, int height, unsigned int *sse) {
  return sub_pix_var32xh(src, src_stride, x_offset, y_offset, dst, dst_stride,
                         second_pred, second_stride, 1, height, sse);
}

typedef void (*get_var_avx2)(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *ref_ptr, int ref_stride,
                             unsigned int *sse, int *sum);
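
// Each vpx_varianceWxH function below returns
//   variance = sse - (sum * sum) / (W * H),
// with the division implemented as a right shift by log2(W * H)
// (>> 8 for 16x16, >> 10 for 32x32, >> 12 for 64x64, and so on).
//
// For reference, a minimal scalar sketch of the same computation (not part
// of the library; W, H and LOG2_WH below are illustrative placeholders):
//
//   int64_t sum = 0;
//   uint64_t sse64 = 0;
//   for (int r = 0; r < H; ++r) {
//     for (int c = 0; c < W; ++c) {
//       const int d = src[r * src_stride + c] - ref[r * ref_stride + c];
//       sum += d;
//       sse64 += (int64_t)d * d;
//     }
//   }
//   *sse = (unsigned int)sse64;
//   return *sse - (unsigned int)((sum * sum) >> LOG2_WH);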
unsigned int vpx_variance16x8_avx2(const uint8_t *src_ptr, int src_stride,
                                   const uint8_t *ref_ptr, int ref_stride,
                                   unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 8, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 7);
}

unsigned int vpx_variance16x16_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 16, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 8);
}

unsigned int vpx_variance16x32_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 32, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 9);
}
unsigned int vpx_variance32x16_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance32_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 16, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 9);
}

unsigned int vpx_variance32x32_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  __m128i vsum_128;
  variance32_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 32, &vsse, &vsum);
  vsum_128 = _mm_add_epi16(_mm256_castsi256_si128(vsum),
                           _mm256_extractf128_si256(vsum, 1));
  vsum_128 = _mm_add_epi32(_mm_cvtepi16_epi32(vsum_128),
                           _mm_cvtepi16_epi32(_mm_srli_si128(vsum_128, 8)));
  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 10);
}
unsigned int vpx_variance32x64_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  __m128i vsum_128;
  variance32_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 64, &vsse, &vsum);
  vsum = sum_to_32bit_avx2(vsum);
  vsum_128 = _mm_add_epi32(_mm256_castsi256_si128(vsum),
                           _mm256_extractf128_si256(vsum, 1));
  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 11);
}

unsigned int vpx_variance64x32_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  __m256i vsse = _mm256_setzero_si256();
  __m256i vsum = _mm256_setzero_si256();
  __m128i vsum_128;
  int sum;
  variance64_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 32, &vsse, &vsum);
  vsum = sum_to_32bit_avx2(vsum);
  vsum_128 = _mm_add_epi32(_mm256_castsi256_si128(vsum),
                           _mm256_extractf128_si256(vsum, 1));
  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 11);
}
unsigned int vpx_variance64x64_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  __m256i vsse = _mm256_setzero_si256();
  __m256i vsum = _mm256_setzero_si256();
  __m128i vsum_128;
  int sum;
  int i = 0;

  for (i = 0; i < 2; i++) {
    __m256i vsum16;
    variance64_avx2(src_ptr + 32 * i * src_stride, src_stride,
                    ref_ptr + 32 * i * ref_stride, ref_stride, 32, &vsse,
                    &vsum16);
    vsum = _mm256_add_epi32(vsum, sum_to_32bit_avx2(vsum16));
  }
  vsum_128 = _mm_add_epi32(_mm256_castsi256_si128(vsum),
                           _mm256_extractf128_si256(vsum, 1));
  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
  return *sse - (unsigned int)(((int64_t)sum * sum) >> 12);
}
unsigned int vpx_mse16x8_avx2(const uint8_t *src_ptr, int src_stride,
                              const uint8_t *ref_ptr, int ref_stride,
                              unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 8, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse;
}

unsigned int vpx_mse16x16_avx2(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *ref_ptr, int ref_stride,
                               unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 16, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse;
}
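
// The 64-wide sub-pixel variants below split the block into left and right
// 32-wide halves, run the 32xh kernel on each, and combine the partial sums
// and SSEs before computing the final variance.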
unsigned int vpx_sub_pixel_variance64x64_avx2(
    const uint8_t *src_ptr, int src_stride, int x_offset, int y_offset,
    const uint8_t *ref_ptr, int ref_stride, unsigned int *sse) {
  unsigned int sse1;
  const int se1 = sub_pixel_variance32xh_avx2(
      src_ptr, src_stride, x_offset, y_offset, ref_ptr, ref_stride, 64, &sse1);
  unsigned int sse2;
  const int se2 =
      sub_pixel_variance32xh_avx2(src_ptr + 32, src_stride, x_offset, y_offset,
                                  ref_ptr + 32, ref_stride, 64, &sse2);
  const int se = se1 + se2;
  *sse = sse1 + sse2;
  return *sse - (uint32_t)(((int64_t)se * se) >> 12);
}

unsigned int vpx_sub_pixel_variance32x32_avx2(
    const uint8_t *src_ptr, int src_stride, int x_offset, int y_offset,
    const uint8_t *ref_ptr, int ref_stride, unsigned int *sse) {
  const int se = sub_pixel_variance32xh_avx2(
      src_ptr, src_stride, x_offset, y_offset, ref_ptr, ref_stride, 32, sse);
  return *sse - (uint32_t)(((int64_t)se * se) >> 10);
}

unsigned int vpx_sub_pixel_avg_variance64x64_avx2(
    const uint8_t *src_ptr, int src_stride, int x_offset, int y_offset,
    const uint8_t *ref_ptr, int ref_stride, unsigned int *sse,
    const uint8_t *second_pred) {
  unsigned int sse1;
  const int se1 = sub_pixel_avg_variance32xh_avx2(src_ptr, src_stride, x_offset,
                                                  y_offset, ref_ptr, ref_stride,
                                                  second_pred, 64, 64, &sse1);
  unsigned int sse2;
  const int se2 = sub_pixel_avg_variance32xh_avx2(
      src_ptr + 32, src_stride, x_offset, y_offset, ref_ptr + 32, ref_stride,
      second_pred + 32, 64, 64, &sse2);
  const int se = se1 + se2;
  *sse = sse1 + sse2;
  return *sse - (uint32_t)(((int64_t)se * se) >> 12);
}

unsigned int vpx_sub_pixel_avg_variance32x32_avx2(
    const uint8_t *src_ptr, int src_stride, int x_offset, int y_offset,
    const uint8_t *ref_ptr, int ref_stride, unsigned int *sse,
    const uint8_t *second_pred) {
  // Process 32 elements in parallel.
  const int se = sub_pixel_avg_variance32xh_avx2(src_ptr, src_stride, x_offset,
                                                 y_offset, ref_ptr, ref_stride,
                                                 second_pred, 32, 32, sse);
  return *sse - (uint32_t)(((int64_t)se * se) >> 10);
}