/* vp9_diamond_search_sad_avx.c */
/*
 *  Copyright (c) 2015 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
  10. #if defined(_MSC_VER)
  11. #include <intrin.h>
  12. #endif
  13. #include <emmintrin.h>
  14. #include <smmintrin.h>
  15. #include "vpx_dsp/vpx_dsp_common.h"
  16. #include "vp9/encoder/vp9_encoder.h"
  17. #include "vpx_ports/mem.h"
  18. #ifdef __GNUC__
  19. #define LIKELY(v) __builtin_expect(v, 1)
  20. #define UNLIKELY(v) __builtin_expect(v, 0)
  21. #else
  22. #define LIKELY(v) (v)
  23. #define UNLIKELY(v) (v)
  24. #endif
  25. static INLINE int_mv pack_int_mv(int16_t row, int16_t col) {
  26. int_mv result;
  27. result.as_mv.row = row;
  28. result.as_mv.col = col;
  29. return result;
  30. }
  31. static INLINE MV_JOINT_TYPE get_mv_joint(const int_mv mv) {
  32. // This is simplified from the C implementation to utilise that
  33. // x->nmvjointsadcost[1] == x->nmvjointsadcost[2] and
  34. // x->nmvjointsadcost[1] == x->nmvjointsadcost[3]
  35. return mv.as_int == 0 ? 0 : 1;
  36. }
  37. static INLINE int mv_cost(const int_mv mv, const int *joint_cost,
  38. int *const comp_cost[2]) {
  39. return joint_cost[get_mv_joint(mv)] + comp_cost[0][mv.as_mv.row] +
  40. comp_cost[1][mv.as_mv.col];
  41. }
  42. static int mvsad_err_cost(const MACROBLOCK *x, const int_mv mv, const MV *ref,
  43. int sad_per_bit) {
  44. const int_mv diff =
  45. pack_int_mv(mv.as_mv.row - ref->row, mv.as_mv.col - ref->col);
  46. return ROUND_POWER_OF_TWO(
  47. (unsigned)mv_cost(diff, x->nmvjointsadcost, x->nmvsadcost) * sad_per_bit,
  48. VP9_PROB_COST_SHIFT);
  49. }
/*****************************************************************************
 * This function utilizes 3 properties of the cost function lookup tables,   *
 * constructed in using 'cal_nmvjointsadcost' and 'cal_nmvsadcosts' in       *
 * vp9_encoder.c.                                                            *
 * For the joint cost:                                                       *
 *   - mvjointsadcost[1] == mvjointsadcost[2] == mvjointsadcost[3]           *
 * For the component costs:                                                  *
 *   - For all i: mvsadcost[0][i] == mvsadcost[1][i]                         *
 *         (Equal costs for both components)                                 *
 *   - For all i: mvsadcost[0][i] == mvsadcost[0][-i]                        *
 *         (Cost function is even)                                           *
 * If these do not hold, then this function cannot be used without           *
 * modification, in which case you can revert to using the C implementation, *
 * which does not rely on these properties.                                  *
 *****************************************************************************/
/*
 * Diamond-pattern full-pel motion search, vectorised 4 candidates at a time.
 * Despite the "_avx" file suffix, only SSE2/SSSE3/SSE4.1 intrinsics are used
 * (e.g. _mm_abs_epi16, _mm_minpos_epu16, _mm_test_all_zeros).
 *
 * Each int_mv packs {row, col} as two int16s in one 32-bit lane, so a __m128i
 * holds four candidate MVs at once.
 *
 * Returns the best (SAD + scaled MV rate) found; writes the best MV to
 * *best_mv and counts in *num00 the steps where the centre did not move.
 */
int vp9_diamond_search_sad_avx(const MACROBLOCK *x,
                               const search_site_config *cfg, MV *ref_mv,
                               MV *best_mv, int search_param, int sad_per_bit,
                               int *num00, const vp9_variance_fn_ptr_t *fn_ptr,
                               const MV *center_mv) {
  // Search bounds, broadcast so all four packed MVs can be clamped at once.
  const int_mv maxmv = pack_int_mv(x->mv_limits.row_max, x->mv_limits.col_max);
  const __m128i v_max_mv_w = _mm_set1_epi32(maxmv.as_int);
  const int_mv minmv = pack_int_mv(x->mv_limits.row_min, x->mv_limits.col_min);
  const __m128i v_min_mv_w = _mm_set1_epi32(minmv.as_int);
  const __m128i v_spb_d = _mm_set1_epi32(sad_per_bit);
  const __m128i v_joint_cost_0_d = _mm_set1_epi32(x->nmvjointsadcost[0]);
  const __m128i v_joint_cost_1_d = _mm_set1_epi32(x->nmvjointsadcost[1]);

  // search_param determines the length of the initial step and hence the
  // number of iterations.
  // 0 = initial step (MAX_FIRST_STEP) pel
  // 1 = (MAX_FIRST_STEP/2) pel,
  // 2 = (MAX_FIRST_STEP/4) pel...
  const MV *ss_mv = &cfg->ss_mv[cfg->searches_per_step * search_param];
  const intptr_t *ss_os = &cfg->ss_os[cfg->searches_per_step * search_param];
  const int tot_steps = cfg->total_steps - search_param;

  // MV cost is measured relative to the prediction centre; center_mv is in
  // 1/8-pel units, hence the >> 3 to full pel.
  const int_mv fcenter_mv =
      pack_int_mv(center_mv->row >> 3, center_mv->col >> 3);
  const __m128i vfcmv = _mm_set1_epi32(fcenter_mv.as_int);

  const int ref_row = clamp(ref_mv->row, minmv.as_mv.row, maxmv.as_mv.row);
  const int ref_col = clamp(ref_mv->col, minmv.as_mv.col, maxmv.as_mv.col);

  int_mv bmv = pack_int_mv(ref_row, ref_col);
  int_mv new_bmv = bmv;
  __m128i v_bmv_w = _mm_set1_epi32(bmv.as_int);

  const int what_stride = x->plane[0].src.stride;
  const int in_what_stride = x->e_mbd.plane[0].pre[0].stride;
  const uint8_t *const what = x->plane[0].src.buf;
  const uint8_t *const in_what =
      x->e_mbd.plane[0].pre[0].buf + ref_row * in_what_stride + ref_col;

  // Work out the start point for the search
  const uint8_t *best_address = in_what;
  const uint8_t *new_best_address = best_address;
  // The reference address of the current best candidate, broadcast so the
  // four candidate addresses can be formed by adding the ss_os offsets.
  // Pointer width differs between targets, hence the two variants.
#if ARCH_X86_64
  __m128i v_ba_q = _mm_set1_epi64x((intptr_t)best_address);
#else
  __m128i v_ba_d = _mm_set1_epi32((intptr_t)best_address);
#endif

  unsigned int best_sad;
  int i, j, step;

  // Check the prerequisite cost function properties that are easy to check
  // in an assert. See the function-level documentation for details on all
  // prerequisites.
  assert(x->nmvjointsadcost[1] == x->nmvjointsadcost[2]);
  assert(x->nmvjointsadcost[1] == x->nmvjointsadcost[3]);

  // Check the starting position
  best_sad = fn_ptr->sdf(what, what_stride, in_what, in_what_stride);
  best_sad += mvsad_err_cost(x, bmv, &fcenter_mv.as_mv, sad_per_bit);

  *num00 = 0;

  for (i = 0, step = 0; step < tot_steps; step++) {
    for (j = 0; j < cfg->searches_per_step; j += 4, i += 4) {
      __m128i v_sad_d, v_cost_d, v_outside_d, v_inside_d, v_diff_mv_w;
#if ARCH_X86_64
      __m128i v_blocka[2];
#else
      __m128i v_blocka[1];
#endif

      // Compute the candidate motion vectors
      const __m128i v_ss_mv_w = _mm_loadu_si128((const __m128i *)&ss_mv[i]);
      const __m128i v_these_mv_w = _mm_add_epi16(v_bmv_w, v_ss_mv_w);
      // Clamp them to the search bounds
      __m128i v_these_mv_clamp_w = v_these_mv_w;
      v_these_mv_clamp_w = _mm_min_epi16(v_these_mv_clamp_w, v_max_mv_w);
      v_these_mv_clamp_w = _mm_max_epi16(v_these_mv_clamp_w, v_min_mv_w);

      // The ones that did not change are inside the search area
      v_inside_d = _mm_cmpeq_epi32(v_these_mv_clamp_w, v_these_mv_w);

      // If none of them are inside, then move on
      // NOTE(review): LIKELY on "all four candidates out of bounds" looks
      // inverted -- for typical frames this branch should be rare. Confirm
      // against upstream; the annotation only affects branch layout, not
      // correctness.
      if (LIKELY(_mm_test_all_zeros(v_inside_d, v_inside_d))) {
        continue;
      }

      // The inverse mask indicates which of the MVs are outside
      v_outside_d = _mm_xor_si128(v_inside_d, _mm_set1_epi8((int8_t)0xff));
      // Shift right to keep the sign bit clear, we will use this later
      // to set the cost to the maximum value.
      v_outside_d = _mm_srli_epi32(v_outside_d, 1);

      // Compute the difference MV
      v_diff_mv_w = _mm_sub_epi16(v_these_mv_clamp_w, vfcmv);
      // We utilise the fact that the cost function is even, and use the
      // absolute difference. This allows us to use unsigned indexes later
      // and reduces cache pressure somewhat as only a half of the table
      // is ever referenced.
      v_diff_mv_w = _mm_abs_epi16(v_diff_mv_w);

      // Compute the SIMD pointer offsets.
      {
#if ARCH_X86_64  //  sizeof(intptr_t) == 8
        // Load the offsets
        __m128i v_bo10_q = _mm_loadu_si128((const __m128i *)&ss_os[i + 0]);
        __m128i v_bo32_q = _mm_loadu_si128((const __m128i *)&ss_os[i + 2]);
        // Set the ones falling outside to zero, so out-of-bounds candidates
        // alias the (valid) current best address and stay dereferenceable.
        v_bo10_q = _mm_and_si128(v_bo10_q, _mm_cvtepi32_epi64(v_inside_d));
        v_bo32_q =
            _mm_and_si128(v_bo32_q, _mm_unpackhi_epi32(v_inside_d, v_inside_d));
        // Compute the candidate addresses
        v_blocka[0] = _mm_add_epi64(v_ba_q, v_bo10_q);
        v_blocka[1] = _mm_add_epi64(v_ba_q, v_bo32_q);
#else  // ARCH_X86 //  sizeof(intptr_t) == 4
        __m128i v_bo_d = _mm_loadu_si128((const __m128i *)&ss_os[i]);
        v_bo_d = _mm_and_si128(v_bo_d, v_inside_d);
        v_blocka[0] = _mm_add_epi32(v_ba_d, v_bo_d);
#endif
      }

      // 4-way SAD of the source block against the four candidate addresses.
      fn_ptr->sdx4df(what, what_stride, (const uint8_t **)&v_blocka[0],
                     in_what_stride, (uint32_t *)&v_sad_d);

      // Look up the component cost of the residual motion vector
      {
        const int32_t row0 = _mm_extract_epi16(v_diff_mv_w, 0);
        const int32_t col0 = _mm_extract_epi16(v_diff_mv_w, 1);
        const int32_t row1 = _mm_extract_epi16(v_diff_mv_w, 2);
        const int32_t col1 = _mm_extract_epi16(v_diff_mv_w, 3);
        const int32_t row2 = _mm_extract_epi16(v_diff_mv_w, 4);
        const int32_t col2 = _mm_extract_epi16(v_diff_mv_w, 5);
        const int32_t row3 = _mm_extract_epi16(v_diff_mv_w, 6);
        const int32_t col3 = _mm_extract_epi16(v_diff_mv_w, 7);

        // Note: This is a use case for vpgather in AVX2
        const uint32_t cost0 = x->nmvsadcost[0][row0] + x->nmvsadcost[0][col0];
        const uint32_t cost1 = x->nmvsadcost[0][row1] + x->nmvsadcost[0][col1];
        const uint32_t cost2 = x->nmvsadcost[0][row2] + x->nmvsadcost[0][col2];
        const uint32_t cost3 = x->nmvsadcost[0][row3] + x->nmvsadcost[0][col3];

        __m128i v_cost_10_d, v_cost_32_d;
        v_cost_10_d = _mm_cvtsi32_si128(cost0);
        v_cost_10_d = _mm_insert_epi32(v_cost_10_d, cost1, 1);
        v_cost_32_d = _mm_cvtsi32_si128(cost2);
        v_cost_32_d = _mm_insert_epi32(v_cost_32_d, cost3, 1);
        v_cost_d = _mm_unpacklo_epi64(v_cost_10_d, v_cost_32_d);
      }

      // Now add in the joint cost
      {
        const __m128i v_sel_d =
            _mm_cmpeq_epi32(v_diff_mv_w, _mm_setzero_si128());
        const __m128i v_joint_cost_d =
            _mm_blendv_epi8(v_joint_cost_1_d, v_joint_cost_0_d, v_sel_d);
        v_cost_d = _mm_add_epi32(v_cost_d, v_joint_cost_d);
      }

      // Multiply by sad_per_bit
      v_cost_d = _mm_mullo_epi32(v_cost_d, v_spb_d);
      // ROUND_POWER_OF_TWO(v_cost_d, VP9_PROB_COST_SHIFT)
      v_cost_d = _mm_add_epi32(v_cost_d,
                               _mm_set1_epi32(1 << (VP9_PROB_COST_SHIFT - 1)));
      v_cost_d = _mm_srai_epi32(v_cost_d, VP9_PROB_COST_SHIFT);
      // Add the cost to the sad
      v_sad_d = _mm_add_epi32(v_sad_d, v_cost_d);

      // Make the motion vectors outside the search area have max cost
      // by or'ing in the comparison mask, this way the minimum search won't
      // pick them.
      v_sad_d = _mm_or_si128(v_sad_d, v_outside_d);

      // Find the minimum value and index horizontally in v_sad_d
      {
        // Try speculatively on 16 bits, so we can use the minpos intrinsic
        const __m128i v_sad_w = _mm_packus_epi32(v_sad_d, v_sad_d);
        const __m128i v_minp_w = _mm_minpos_epu16(v_sad_w);

        uint32_t local_best_sad = _mm_extract_epi16(v_minp_w, 0);
        uint32_t local_best_idx = _mm_extract_epi16(v_minp_w, 1);

        // If the local best value is not saturated, just use it, otherwise
        // find the horizontal minimum again the hard way on 32 bits.
        // This is executed rarely.
        if (UNLIKELY(local_best_sad == 0xffff)) {
          __m128i v_loval_d, v_hival_d, v_loidx_d, v_hiidx_d, v_sel_d;

          // Pairwise tournament reduction: 4 lanes -> 2 -> 1, carrying the
          // lane index alongside the value.
          v_loval_d = v_sad_d;
          v_loidx_d = _mm_set_epi32(3, 2, 1, 0);
          v_hival_d = _mm_srli_si128(v_loval_d, 8);
          v_hiidx_d = _mm_srli_si128(v_loidx_d, 8);

          v_sel_d = _mm_cmplt_epi32(v_hival_d, v_loval_d);

          v_loval_d = _mm_blendv_epi8(v_loval_d, v_hival_d, v_sel_d);
          v_loidx_d = _mm_blendv_epi8(v_loidx_d, v_hiidx_d, v_sel_d);
          v_hival_d = _mm_srli_si128(v_loval_d, 4);
          v_hiidx_d = _mm_srli_si128(v_loidx_d, 4);

          v_sel_d = _mm_cmplt_epi32(v_hival_d, v_loval_d);

          v_loval_d = _mm_blendv_epi8(v_loval_d, v_hival_d, v_sel_d);
          v_loidx_d = _mm_blendv_epi8(v_loidx_d, v_hiidx_d, v_sel_d);
          local_best_sad = _mm_extract_epi32(v_loval_d, 0);
          local_best_idx = _mm_extract_epi32(v_loidx_d, 0);
        }

        // Update the global minimum if the local minimum is smaller
        if (LIKELY(local_best_sad < best_sad)) {
          // Read the winning lane back out of the vector registers via
          // type-punned pointers into the unclamped MVs / addresses.
          new_bmv = ((const int_mv *)&v_these_mv_w)[local_best_idx];
          new_best_address = ((const uint8_t **)v_blocka)[local_best_idx];

          best_sad = local_best_sad;
        }
      }
    }

    // Commit the best candidate of this step and re-broadcast it for the
    // next step's candidate generation.
    bmv = new_bmv;
    best_address = new_best_address;

    v_bmv_w = _mm_set1_epi32(bmv.as_int);
#if ARCH_X86_64
    v_ba_q = _mm_set1_epi64x((intptr_t)best_address);
#else
    v_ba_d = _mm_set1_epi32((intptr_t)best_address);
#endif

    // The centre did not move this step; callers use num00 to skip
    // re-evaluating the same centre.
    if (UNLIKELY(best_address == in_what)) {
      (*num00)++;
    }
  }

  *best_mv = bmv.as_mv;
  return best_sad;
}