IDWT 5x3: generalize SSE2 version for AVX2
authorEven Rouault <even.rouault@spatialys.com>
Wed, 21 Jun 2017 10:12:58 +0000 (12:12 +0200)
committerEven Rouault <even.rouault@spatialys.com>
Wed, 21 Jun 2017 10:12:58 +0000 (12:12 +0200)
Thanks to our macros that abstract SSE use, the functions can use
AVX2 when available (at compile time)

This brings an extra 23% speed improvement on bench_dwt in 64bit builds
with AVX2 compared to SSE2.

src/lib/openjp2/dwt.c
src/lib/openjp2/opj_malloc.c
src/lib/openjp2/opj_malloc.h

index 2af375aaa480df72e6a6d43d58e5bc068824203e..4a5ba609ca95cafbf5fa426eeb9adb809cab2320 100644 (file)
@@ -52,6 +52,9 @@
 #ifdef __SSSE3__
 #include <tmmintrin.h>
 #endif
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
 
 #if defined(__GNUC__)
 #pragma GCC poison malloc calloc realloc free
 #define OPJ_WS(i) v->mem[(i)*2]
 #define OPJ_WD(i) v->mem[(1+(i)*2)]
 
-#define PARALLEL_COLS_53     8
+#ifdef __AVX2__
+/** Number of int32 values in a AVX2 register */
+#define VREG_INT_COUNT       8
+#else
+/** Number of int32 values in a SSE2 register */
+#define VREG_INT_COUNT       4
+#endif
+
+/** Number of columns that we can process in parallel in the vertical pass */
+#define PARALLEL_COLS_53     (2*VREG_INT_COUNT)
 
 /** @name Local data structures */
 /*@{*/
@@ -553,19 +565,55 @@ static void opj_idwt53_h(const opj_dwt_t *dwt,
 #endif
 }
 
-#if defined(__SSE2__) && !defined(STANDARD_SLOW_VERSION)
+#if (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION)
 
 /* Conveniency macros to improve the readabilty of the formulas */
-#define LOADU(x)   _mm_loadu_si128((const __m128i*)(x))
-#define STORE(x,y) _mm_store_si128((__m128i*)(x),(y))
-#define ADD(x,y)   _mm_add_epi32((x),(y))
+#if __AVX2__
+#define VREG        __m256i
+#define LOAD_CST(x) _mm256_set1_epi32(x)
+#define LOAD(x)     _mm256_load_si256((const VREG*)(x))
+#define LOADU(x)    _mm256_loadu_si256((const VREG*)(x))
+#define STORE(x,y)  _mm256_store_si256((VREG*)(x),(y))
+#define STOREU(x,y) _mm256_storeu_si256((VREG*)(x),(y))
+#define ADD(x,y)    _mm256_add_epi32((x),(y))
+#define SUB(x,y)    _mm256_sub_epi32((x),(y))
+#define SAR(x,y)    _mm256_srai_epi32((x),(y))
+#else
+#define VREG        __m128i
+#define LOAD_CST(x) _mm_set1_epi32(x)
+#define LOAD(x)     _mm_load_si128((const VREG*)(x))
+#define LOADU(x)    _mm_loadu_si128((const VREG*)(x))
+#define STORE(x,y)  _mm_store_si128((VREG*)(x),(y))
+#define STOREU(x,y) _mm_storeu_si128((VREG*)(x),(y))
+#define ADD(x,y)    _mm_add_epi32((x),(y))
+#define SUB(x,y)    _mm_sub_epi32((x),(y))
+#define SAR(x,y)    _mm_srai_epi32((x),(y))
+#endif
 #define ADD3(x,y,z) ADD(ADD(x,y),z)
-#define SUB(x,y)   _mm_sub_epi32((x),(y))
-#define SAR(x,y)   _mm_srai_epi32((x),(y))
 
-/** Vertical inverse 5x3 wavelet transform for 8 columns, when top-most
- * pixel is on even coordinate */
-static void opj_idwt53_v_cas0_8cols_SSE2(
+static
+void opj_idwt53_v_final_memcpy(OPJ_INT32* tiledp_col,
+                               const OPJ_INT32* tmp,
+                               OPJ_INT32 len,
+                               OPJ_INT32 stride)
+{
+    OPJ_INT32 i;
+    for (i = 0; i < len; ++i) {
+        /* A memcpy(&tiledp_col[i * stride + 0],
+                    &tmp[PARALLEL_COLS_53 * i + 0],
+                    PARALLEL_COLS_53 * sizeof(OPJ_INT32))
+           would do but would be a tiny bit slower.
+           We can take here advantage of our knowledge of alignment */
+        STOREU(&tiledp_col[i * stride + 0],
+               LOAD(&tmp[PARALLEL_COLS_53 * i + 0]));
+        STOREU(&tiledp_col[i * stride + VREG_INT_COUNT],
+               LOAD(&tmp[PARALLEL_COLS_53 * i + VREG_INT_COUNT]));
+    }
+}
+
+/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
+ * 16 in AVX2, when top-most pixel is on even coordinate */
+static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(
     OPJ_INT32* tmp,
     const OPJ_INT32 sn,
     const OPJ_INT32 len,
@@ -576,17 +624,28 @@ static void opj_idwt53_v_cas0_8cols_SSE2(
     const OPJ_INT32* in_odd = &tiledp_col[sn * stride];
 
     OPJ_INT32 i, j;
-    __m128i d1c_0, d1n_0, s1n_0, s0c_0, s0n_0;
-    __m128i d1c_1, d1n_1, s1n_1, s0c_1, s0n_1;
-    const __m128i two = _mm_set1_epi32(2);
+    VREG d1c_0, d1n_0, s1n_0, s0c_0, s0n_0;
+    VREG d1c_1, d1n_1, s1n_1, s0c_1, s0n_1;
+    const VREG two = LOAD_CST(2);
 
     assert(len > 1);
+#if __AVX2__
+    assert(PARALLEL_COLS_53 == 16);
+    assert(VREG_INT_COUNT == 8);
+#else
     assert(PARALLEL_COLS_53 == 8);
+    assert(VREG_INT_COUNT == 4);
+#endif
+
+    /* Note: loads of input even/odd values must be done in a unaligned */
+    /* fashion. But stores in tmp can be done with aligned store, since */
+    /* the temporary buffer is properly aligned */
+    assert((size_t)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
 
     s1n_0 = LOADU(in_even + 0);
-    s1n_1 = LOADU(in_even + 4);
+    s1n_1 = LOADU(in_even + VREG_INT_COUNT);
     d1n_0 = LOADU(in_odd);
-    d1n_1 = LOADU(in_odd + 4);
+    d1n_1 = LOADU(in_odd + VREG_INT_COUNT);
 
     /* s0n = s1n - ((d1n + 1) >> 1); <==> */
     /* s0n = s1n - ((d1n + d1n + 2) >> 2); */
@@ -600,29 +659,29 @@ static void opj_idwt53_v_cas0_8cols_SSE2(
         s0c_1 = s0n_1;
 
         s1n_0 = LOADU(in_even + j * stride);
-        s1n_1 = LOADU(in_even + j * stride + 4);
+        s1n_1 = LOADU(in_even + j * stride + VREG_INT_COUNT);
         d1n_0 = LOADU(in_odd + j * stride);
-        d1n_1 = LOADU(in_odd + j * stride + 4);
+        d1n_1 = LOADU(in_odd + j * stride + VREG_INT_COUNT);
 
         /*s0n = s1n - ((d1c + d1n + 2) >> 2);*/
         s0n_0 = SUB(s1n_0, SAR(ADD3(d1c_0, d1n_0, two), 2));
         s0n_1 = SUB(s1n_1, SAR(ADD3(d1c_1, d1n_1, two), 2));
 
         STORE(tmp + PARALLEL_COLS_53 * (i + 0), s0c_0);
-        STORE(tmp + PARALLEL_COLS_53 * (i + 0) + 4, s0c_1);
+        STORE(tmp + PARALLEL_COLS_53 * (i + 0) + VREG_INT_COUNT, s0c_1);
 
         /* d1c + ((s0c + s0n) >> 1) */
         STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 0,
               ADD(d1c_0, SAR(ADD(s0c_0, s0n_0), 1)));
-        STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 4,
+        STORE(tmp + PARALLEL_COLS_53 * (i + 1) + VREG_INT_COUNT,
               ADD(d1c_1, SAR(ADD(s0c_1, s0n_1), 1)));
     }
 
     STORE(tmp + PARALLEL_COLS_53 * (i + 0) + 0, s0n_0);
-    STORE(tmp + PARALLEL_COLS_53 * (i + 0) + 4, s0n_1);
+    STORE(tmp + PARALLEL_COLS_53 * (i + 0) + VREG_INT_COUNT, s0n_1);
 
     if (len & 1) {
-        __m128i tmp_len_minus_1;
+        VREG tmp_len_minus_1;
         s1n_0 = LOADU(in_even + ((len - 1) / 2) * stride);
         /* tmp_len_minus_1 = s1n - ((d1n + 1) >> 1); */
         tmp_len_minus_1 = SUB(s1n_0, SAR(ADD3(d1n_0, d1n_0, two), 2));
@@ -631,31 +690,30 @@ static void opj_idwt53_v_cas0_8cols_SSE2(
         STORE(tmp + 8 * (len - 2),
               ADD(d1n_0, SAR(ADD(s0n_0, tmp_len_minus_1), 1)));
 
-        s1n_1 = LOADU(in_even + ((len - 1) / 2) * stride + 4);
+        s1n_1 = LOADU(in_even + ((len - 1) / 2) * stride + VREG_INT_COUNT);
         /* tmp_len_minus_1 = s1n - ((d1n + 1) >> 1); */
         tmp_len_minus_1 = SUB(s1n_1, SAR(ADD3(d1n_1, d1n_1, two), 2));
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, tmp_len_minus_1);
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT,
+              tmp_len_minus_1);
         /* d1n + ((s0n + tmp_len_minus_1) >> 1) */
-        STORE(tmp + PARALLEL_COLS_53 * (len - 2) + 4,
+        STORE(tmp + PARALLEL_COLS_53 * (len - 2) + VREG_INT_COUNT,
               ADD(d1n_1, SAR(ADD(s0n_1, tmp_len_minus_1), 1)));
 
 
     } else {
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0, ADD(d1n_0, s0n_0));
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, ADD(d1n_1, s0n_1));
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0,
+              ADD(d1n_0, s0n_0));
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT,
+              ADD(d1n_1, s0n_1));
     }
 
-    for (i = 0; i < len; ++i) {
-        memcpy(&tiledp_col[i * stride],
-               &tmp[PARALLEL_COLS_53 * i],
-               PARALLEL_COLS_53 * sizeof(OPJ_INT32));
-    }
+    opj_idwt53_v_final_memcpy(tiledp_col, tmp, len, stride);
 }
 
 
-/** Vertical inverse 5x3 wavelet transform for 8 columns, when top-most
- * pixel is on odd coordinate */
-static void opj_idwt53_v_cas1_8cols_SSE2(
+/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
+ * 16 in AVX2, when top-most pixel is on odd coordinate */
+static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(
     OPJ_INT32* tmp,
     const OPJ_INT32 sn,
     const OPJ_INT32 len,
@@ -664,15 +722,26 @@ static void opj_idwt53_v_cas1_8cols_SSE2(
 {
     OPJ_INT32 i, j;
 
-    __m128i s1_0, s2_0, dc_0, dn_0;
-    __m128i s1_1, s2_1, dc_1, dn_1;
-    const __m128i two = _mm_set1_epi32(2);
+    VREG s1_0, s2_0, dc_0, dn_0;
+    VREG s1_1, s2_1, dc_1, dn_1;
+    const VREG two = LOAD_CST(2);
 
     const OPJ_INT32* in_even = &tiledp_col[sn * stride];
     const OPJ_INT32* in_odd = &tiledp_col[0];
 
     assert(len > 2);
+#if __AVX2__
+    assert(PARALLEL_COLS_53 == 16);
+    assert(VREG_INT_COUNT == 8);
+#else
     assert(PARALLEL_COLS_53 == 8);
+    assert(VREG_INT_COUNT == 4);
+#endif
+
+    /* Note: loads of input even/odd values must be done in a unaligned */
+    /* fashion. But stores in tmp can be done with aligned store, since */
+    /* the temporary buffer is properly aligned */
+    assert((size_t)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
 
     s1_0 = LOADU(in_even + stride);
     /* in_odd[0] - ((in_even[0] + s1 + 2) >> 2); */
@@ -680,30 +749,31 @@ static void opj_idwt53_v_cas1_8cols_SSE2(
                SAR(ADD3(LOADU(in_even + 0), s1_0, two), 2));
     STORE(tmp + PARALLEL_COLS_53 * 0, ADD(LOADU(in_even + 0), dc_0));
 
-    s1_1 = LOADU(in_even + stride + 4);
+    s1_1 = LOADU(in_even + stride + VREG_INT_COUNT);
     /* in_odd[0] - ((in_even[0] + s1 + 2) >> 2); */
-    dc_1 = SUB(LOADU(in_odd + 4),
-               SAR(ADD3(LOADU(in_even + 4), s1_1, two), 2));
-    STORE(tmp + PARALLEL_COLS_53 * 0 + 4, ADD(LOADU(in_even + 4), dc_1));
+    dc_1 = SUB(LOADU(in_odd + VREG_INT_COUNT),
+               SAR(ADD3(LOADU(in_even + VREG_INT_COUNT), s1_1, two), 2));
+    STORE(tmp + PARALLEL_COLS_53 * 0 + VREG_INT_COUNT,
+          ADD(LOADU(in_even + VREG_INT_COUNT), dc_1));
 
     for (i = 1, j = 1; i < (len - 2 - !(len & 1)); i += 2, j++) {
 
         s2_0 = LOADU(in_even + (j + 1) * stride);
-        s2_1 = LOADU(in_even + (j + 1) * stride + 4);
+        s2_1 = LOADU(in_even + (j + 1) * stride + VREG_INT_COUNT);
 
         /* dn = in_odd[j * stride] - ((s1 + s2 + 2) >> 2); */
         dn_0 = SUB(LOADU(in_odd + j * stride),
                    SAR(ADD3(s1_0, s2_0, two), 2));
-        dn_1 = SUB(LOADU(in_odd + j * stride + 4),
+        dn_1 = SUB(LOADU(in_odd + j * stride + VREG_INT_COUNT),
                    SAR(ADD3(s1_1, s2_1, two), 2));
 
         STORE(tmp + PARALLEL_COLS_53 * i, dc_0);
-        STORE(tmp + PARALLEL_COLS_53 * i + 4, dc_1);
+        STORE(tmp + PARALLEL_COLS_53 * i + VREG_INT_COUNT, dc_1);
 
         /* tmp[i + 1] = s1 + ((dn + dc) >> 1); */
         STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 0,
               ADD(s1_0, SAR(ADD(dn_0, dc_0), 1)));
-        STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 4,
+        STORE(tmp + PARALLEL_COLS_53 * (i + 1) + VREG_INT_COUNT,
               ADD(s1_1, SAR(ADD(dn_1, dc_1), 1)));
 
         dc_0 = dn_0;
@@ -712,43 +782,44 @@ static void opj_idwt53_v_cas1_8cols_SSE2(
         s1_1 = s2_1;
     }
     STORE(tmp + PARALLEL_COLS_53 * i, dc_0);
-    STORE(tmp + PARALLEL_COLS_53 * i + 4, dc_1);
+    STORE(tmp + PARALLEL_COLS_53 * i + VREG_INT_COUNT, dc_1);
 
     if (!(len & 1)) {
         /*dn = in_odd[(len / 2 - 1) * stride] - ((s1 + 1) >> 1); */
         dn_0 = SUB(LOADU(in_odd + (len / 2 - 1) * stride),
                    SAR(ADD3(s1_0, s1_0, two), 2));
-        dn_1 = SUB(LOADU(in_odd + (len / 2 - 1) * stride + 4),
+        dn_1 = SUB(LOADU(in_odd + (len / 2 - 1) * stride + VREG_INT_COUNT),
                    SAR(ADD3(s1_1, s1_1, two), 2));
 
         /* tmp[len - 2] = s1 + ((dn + dc) >> 1); */
         STORE(tmp + PARALLEL_COLS_53 * (len - 2) + 0,
               ADD(s1_0, SAR(ADD(dn_0, dc_0), 1)));
-        STORE(tmp + PARALLEL_COLS_53 * (len - 2) + 4,
+        STORE(tmp + PARALLEL_COLS_53 * (len - 2) + VREG_INT_COUNT,
               ADD(s1_1, SAR(ADD(dn_1, dc_1), 1)));
 
         STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0, dn_0);
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, dn_1);
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT, dn_1);
     } else {
         STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0, ADD(s1_0, dc_0));
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, ADD(s1_1, dc_1));
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT,
+              ADD(s1_1, dc_1));
     }
 
-    for (i = 0; i < len; ++i) {
-        memcpy(&tiledp_col[i * stride],
-               &tmp[PARALLEL_COLS_53 * i],
-               PARALLEL_COLS_53 * sizeof(OPJ_INT32));
-    }
+    opj_idwt53_v_final_memcpy(tiledp_col, tmp, len, stride);
 }
 
+#undef VREG
+#undef LOAD_CST
 #undef LOADU
+#undef LOAD
 #undef STORE
+#undef STOREU
 #undef ADD
 #undef ADD3
 #undef SUB
 #undef SAR
 
-#endif /* defined(__SSE2__) && !defined(STANDARD_SLOW_VERSION) */
+#endif /* (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION) */
 
 #if !defined(STANDARD_SLOW_VERSION)
 /** Vertical inverse 5x3 wavelet transform for one column, when top-most
@@ -873,11 +944,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
     if (dwt->cas == 0) {
         /* If len == 1, unmodified value */
 
-#if __SSE2__
+#if (defined(__SSE2__) || defined(__AVX2__))
         if (len > 1 && nb_cols == PARALLEL_COLS_53) {
-            /* Same as below general case, except that thanks to SSE2 */
-            /* we can efficently process 8 columns in parallel */
-            opj_idwt53_v_cas0_8cols_SSE2(dwt->mem, sn, len, tiledp_col, stride);
+            /* Same as below general case, except that thanks to SSE2/AVX2 */
+            /* we can efficently process 8/16 columns in parallel */
+            opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
             return;
         }
 #endif
@@ -916,11 +987,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
             return;
         }
 
-#ifdef __SSE2__
+#if (defined(__SSE2__) || defined(__AVX2__))
         if (len > 2 && nb_cols == PARALLEL_COLS_53) {
-            /* Same as below general case, except that thanks to SSE2 */
-            /* we can efficently process 8 columns in parallel */
-            opj_idwt53_v_cas1_8cols_SSE2(dwt->mem, sn, len, tiledp_col, stride);
+            /* Same as below general case, except that thanks to SSE2/AVX2 */
+            /* we can efficently process 8/16 columns in parallel */
+            opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
             return;
         }
 #endif
@@ -1291,7 +1362,7 @@ static OPJ_BOOL opj_dwt_decode_tile(opj_thread_pool_t* tp,
     /* since for the vertical pass */
     /* we process PARALLEL_COLS_53 columns at a time */
     h_mem_size *= PARALLEL_COLS_53 * sizeof(OPJ_INT32);
-    h.mem = (OPJ_INT32*)opj_aligned_malloc(h_mem_size);
+    h.mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
     if (! h.mem) {
         /* FIXME event manager error callback */
         return OPJ_FALSE;
@@ -1348,7 +1419,7 @@ static OPJ_BOOL opj_dwt_decode_tile(opj_thread_pool_t* tp,
                 if (j == (num_jobs - 1U)) {  /* this will take care of the overflow */
                     job->max_j = rh;
                 }
-                job->h.mem = (OPJ_INT32*)opj_aligned_malloc(h_mem_size);
+                job->h.mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
                 if (!job->h.mem) {
                     /* FIXME event manager error callback */
                     opj_thread_pool_wait_completion(tp, 0);
@@ -1403,7 +1474,7 @@ static OPJ_BOOL opj_dwt_decode_tile(opj_thread_pool_t* tp,
                 if (j == (num_jobs - 1U)) {  /* this will take care of the overflow */
                     job->max_j = rw;
                 }
-                job->v.mem = (OPJ_INT32*)opj_aligned_malloc(h_mem_size);
+                job->v.mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
                 if (!job->v.mem) {
                     /* FIXME event manager error callback */
                     opj_thread_pool_wait_completion(tp, 0);
index 9c438bdb0c6701bdc99cd46f938f24b9829fc2d7..dca91bfcbe9ab39693094b793ee2d2fcc85385df 100644 (file)
@@ -213,6 +213,15 @@ void * opj_aligned_realloc(void *ptr, size_t size)
     return opj_aligned_realloc_n(ptr, 16U, size);
 }
 
+void *opj_aligned_32_malloc(size_t size)
+{
+    return opj_aligned_alloc_n(32U, size);
+}
+void * opj_aligned_32_realloc(void *ptr, size_t size)
+{
+    return opj_aligned_realloc_n(ptr, 32U, size);
+}
+
 void opj_aligned_free(void* ptr)
 {
 #if defined(OPJ_HAVE_POSIX_MEMALIGN) || defined(OPJ_HAVE_MEMALIGN)
index c8c2fc2dbb45403b17610e387973023bebaf027a..7503c28d2a4f18ab5ec588418edd5fef62ae7181 100644 (file)
@@ -71,6 +71,14 @@ void * opj_aligned_malloc(size_t size);
 void * opj_aligned_realloc(void *ptr, size_t size);
 void opj_aligned_free(void* ptr);
 
+/**
+Allocate memory aligned to a 32 byte boundary
+@param size Bytes to allocate
+@return Returns a void pointer to the allocated space, or NULL if there is insufficient memory available
+*/
+void * opj_aligned_32_malloc(size_t size);
+void * opj_aligned_32_realloc(void *ptr, size_t size);
+
 /**
 Reallocate memory blocks.
 @param m Pointer to previously allocated memory block