From: Even Rouault
Date: Wed, 21 Jun 2017 10:12:58 +0000 (+0200)
Subject: IDWT 5x3: generalize SSE2 version for AVX2
X-Git-Tag: v2.2.0~71^2~1
X-Git-Url: https://git.carlh.net/gitweb/?a=commitdiff_plain;h=fd0dc535ad9ae0d369d1039aaf56235583ca64ea;p=openjpeg.git

IDWT 5x3: generalize SSE2 version for AVX2

Thanks to our macros that abstract the SSE use, the functions can use
AVX2 when it is available (at compile time).

This brings an extra 23% speed improvement on bench_dwt in 64-bit builds
with AVX2 compared to SSE2.
---

diff --git a/src/lib/openjp2/dwt.c b/src/lib/openjp2/dwt.c
index 2af375aa..4a5ba609 100644
--- a/src/lib/openjp2/dwt.c
+++ b/src/lib/openjp2/dwt.c
@@ -52,6 +52,9 @@
 #ifdef __SSSE3__
 #include <tmmintrin.h>
 #endif
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
 
 #if defined(__GNUC__)
 #pragma GCC poison malloc calloc realloc free
@@ -63,7 +66,16 @@
 #define OPJ_WS(i) v->mem[(i)*2]
 #define OPJ_WD(i) v->mem[(1+(i)*2)]
 
-#define PARALLEL_COLS_53     8
+#ifdef __AVX2__
+/** Number of int32 values in an AVX2 register */
+#define VREG_INT_COUNT       8
+#else
+/** Number of int32 values in an SSE2 register */
+#define VREG_INT_COUNT       4
+#endif
+
+/** Number of columns that we can process in parallel in the vertical pass */
+#define PARALLEL_COLS_53     (2*VREG_INT_COUNT)
 
 /** @name Local data structures */
 /*@{*/
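The macro block in the next hunk is the heart of the change: every intrinsic the 5x3
code uses is wrapped behind a register-width-agnostic name, so the function bodies
below compile unchanged to 128-bit SSE2 or 256-bit AVX2 code. A minimal,
self-contained sketch of the same pattern (the double_values function and its
operation are made up for illustration; only the macro scheme mirrors the patch):

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>
    #ifdef __AVX2__
    #include <immintrin.h>
    #define VREG            __m256i
    #define VREG_INT_COUNT  8
    #define LOADU(x)        _mm256_loadu_si256((const VREG*)(x))
    #define STOREU(x, y)    _mm256_storeu_si256((VREG*)(x), (y))
    #define ADD(x, y)       _mm256_add_epi32((x), (y))
    #else
    #include <emmintrin.h>
    #define VREG            __m128i
    #define VREG_INT_COUNT  4
    #define LOADU(x)        _mm_loadu_si128((const VREG*)(x))
    #define STOREU(x, y)    _mm_storeu_si128((VREG*)(x), (y))
    #define ADD(x, y)       _mm_add_epi32((x), (y))
    #endif

    /* Doubles len int32 values in place; len must be a multiple of
       VREG_INT_COUNT. Written once, vectorized per the compile target. */
    static void double_values(int32_t* p, size_t len)
    {
        size_t i;
        assert(len % VREG_INT_COUNT == 0);
        for (i = 0; i < len; i += VREG_INT_COUNT) {
            VREG v = LOADU(p + i);
            STOREU(p + i, ADD(v, v));
        }
    }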
@@ -553,19 +565,55 @@ static void opj_idwt53_h(const opj_dwt_t *dwt,
 #endif
 }
 
-#if defined(__SSE2__) && !defined(STANDARD_SLOW_VERSION)
+#if (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION)
 
 /* Conveniency macros to improve the readabilty of the formulas */
-#define LOADU(x)    _mm_loadu_si128((const __m128i*)(x))
-#define STORE(x,y)  _mm_store_si128((__m128i*)(x),(y))
-#define ADD(x,y)    _mm_add_epi32((x),(y))
+#if __AVX2__
+#define VREG        __m256i
+#define LOAD_CST(x) _mm256_set1_epi32(x)
+#define LOAD(x)     _mm256_load_si256((const VREG*)(x))
+#define LOADU(x)    _mm256_loadu_si256((const VREG*)(x))
+#define STORE(x,y)  _mm256_store_si256((VREG*)(x),(y))
+#define STOREU(x,y) _mm256_storeu_si256((VREG*)(x),(y))
+#define ADD(x,y)    _mm256_add_epi32((x),(y))
+#define SUB(x,y)    _mm256_sub_epi32((x),(y))
+#define SAR(x,y)    _mm256_srai_epi32((x),(y))
+#else
+#define VREG        __m128i
+#define LOAD_CST(x) _mm_set1_epi32(x)
+#define LOAD(x)     _mm_load_si128((const VREG*)(x))
+#define LOADU(x)    _mm_loadu_si128((const VREG*)(x))
+#define STORE(x,y)  _mm_store_si128((VREG*)(x),(y))
+#define STOREU(x,y) _mm_storeu_si128((VREG*)(x),(y))
+#define ADD(x,y)    _mm_add_epi32((x),(y))
+#define SUB(x,y)    _mm_sub_epi32((x),(y))
+#define SAR(x,y)    _mm_srai_epi32((x),(y))
+#endif
 #define ADD3(x,y,z) ADD(ADD(x,y),z)
-#define SUB(x,y)    _mm_sub_epi32((x),(y))
-#define SAR(x,y)    _mm_srai_epi32((x),(y))
 
-/** Vertical inverse 5x3 wavelet transform for 8 columns, when top-most
- * pixel is on even coordinate */
-static void opj_idwt53_v_cas0_8cols_SSE2(
+static
+void opj_idwt53_v_final_memcpy(OPJ_INT32* tiledp_col,
+                               const OPJ_INT32* tmp,
+                               OPJ_INT32 len,
+                               OPJ_INT32 stride)
+{
+    OPJ_INT32 i;
+    for (i = 0; i < len; ++i) {
+        /* A memcpy(&tiledp_col[i * stride + 0],
+                    &tmp[PARALLEL_COLS_53 * i + 0],
+                    PARALLEL_COLS_53 * sizeof(OPJ_INT32))
+           would do but would be a tiny bit slower.
+           We can take here advantage of our knowledge of alignment */
+        STOREU(&tiledp_col[i * stride + 0],
+               LOAD(&tmp[PARALLEL_COLS_53 * i + 0]));
+        STOREU(&tiledp_col[i * stride + VREG_INT_COUNT],
+               LOAD(&tmp[PARALLEL_COLS_53 * i + VREG_INT_COUNT]));
+    }
+}
+
+/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
+ * 16 in AVX2, when top-most pixel is on even coordinate */
+static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(
     OPJ_INT32* tmp,
     const OPJ_INT32 sn,
     const OPJ_INT32 len,
@@ -576,17 +624,28 @@ static void opj_idwt53_v_cas0_8cols_SSE2(
     const OPJ_INT32* in_odd = &tiledp_col[sn * stride];
 
     OPJ_INT32 i, j;
-    __m128i d1c_0, d1n_0, s1n_0, s0c_0, s0n_0;
-    __m128i d1c_1, d1n_1, s1n_1, s0c_1, s0n_1;
-    const __m128i two = _mm_set1_epi32(2);
+    VREG d1c_0, d1n_0, s1n_0, s0c_0, s0n_0;
+    VREG d1c_1, d1n_1, s1n_1, s0c_1, s0n_1;
+    const VREG two = LOAD_CST(2);
 
     assert(len > 1);
+#if __AVX2__
+    assert(PARALLEL_COLS_53 == 16);
+    assert(VREG_INT_COUNT == 8);
+#else
     assert(PARALLEL_COLS_53 == 8);
+    assert(VREG_INT_COUNT == 4);
+#endif
+
+    /* Note: loads of input even/odd values must be done in an unaligned */
+    /* fashion. But stores in tmp can be done with aligned store, since */
+    /* the temporary buffer is properly aligned */
+    assert((size_t)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
 
     s1n_0 = LOADU(in_even + 0);
-    s1n_1 = LOADU(in_even + 4);
+    s1n_1 = LOADU(in_even + VREG_INT_COUNT);
     d1n_0 = LOADU(in_odd);
-    d1n_1 = LOADU(in_odd + 4);
+    d1n_1 = LOADU(in_odd + VREG_INT_COUNT);
 
     /* s0n = s1n - ((d1n + 1) >> 1); <==> */
     /* s0n = s1n - ((d1n + d1n + 2) >> 2); */
@@ -600,29 +659,29 @@ static void opj_idwt53_v_cas0_8cols_SSE2(
         s0c_1 = s0n_1;
 
         s1n_0 = LOADU(in_even + j * stride);
-        s1n_1 = LOADU(in_even + j * stride + 4);
+        s1n_1 = LOADU(in_even + j * stride + VREG_INT_COUNT);
         d1n_0 = LOADU(in_odd + j * stride);
-        d1n_1 = LOADU(in_odd + j * stride + 4);
+        d1n_1 = LOADU(in_odd + j * stride + VREG_INT_COUNT);
 
         /*s0n = s1n - ((d1c + d1n + 2) >> 2);*/
         s0n_0 = SUB(s1n_0, SAR(ADD3(d1c_0, d1n_0, two), 2));
         s0n_1 = SUB(s1n_1, SAR(ADD3(d1c_1, d1n_1, two), 2));
 
         STORE(tmp + PARALLEL_COLS_53 * (i + 0), s0c_0);
-        STORE(tmp + PARALLEL_COLS_53 * (i + 0) + 4, s0c_1);
+        STORE(tmp + PARALLEL_COLS_53 * (i + 0) + VREG_INT_COUNT, s0c_1);
 
         /* d1c + ((s0c + s0n) >> 1) */
         STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 0,
               ADD(d1c_0, SAR(ADD(s0c_0, s0n_0), 1)));
-        STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 4,
+        STORE(tmp + PARALLEL_COLS_53 * (i + 1) + VREG_INT_COUNT,
               ADD(d1c_1, SAR(ADD(s0c_1, s0n_1), 1)));
     }
 
     STORE(tmp + PARALLEL_COLS_53 * (i + 0) + 0, s0n_0);
-    STORE(tmp + PARALLEL_COLS_53 * (i + 0) + 4, s0n_1);
+    STORE(tmp + PARALLEL_COLS_53 * (i + 0) + VREG_INT_COUNT, s0n_1);
 
     if (len & 1) {
-        __m128i tmp_len_minus_1;
+        VREG tmp_len_minus_1;
         s1n_0 = LOADU(in_even + ((len - 1) / 2) * stride);
         /* tmp_len_minus_1 = s1n - ((d1n + 1) >> 1); */
         tmp_len_minus_1 = SUB(s1n_0, SAR(ADD3(d1n_0, d1n_0, two), 2));
@@ -631,31 +690,30 @@ static void opj_idwt53_v_cas0_8cols_SSE2(
         STORE(tmp + 8 * (len - 2),
               ADD(d1n_0, SAR(ADD(s0n_0, tmp_len_minus_1), 1)));
 
-        s1n_1 = LOADU(in_even + ((len - 1) / 2) * stride + 4);
+        s1n_1 = LOADU(in_even + ((len - 1) / 2) * stride + VREG_INT_COUNT);
         /* tmp_len_minus_1 = s1n - ((d1n + 1) >> 1); */
         tmp_len_minus_1 = SUB(s1n_1, SAR(ADD3(d1n_1, d1n_1, two), 2));
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, tmp_len_minus_1);
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT,
+              tmp_len_minus_1);
 
         /* d1n + ((s0n + tmp_len_minus_1) >> 1) */
-        STORE(tmp + PARALLEL_COLS_53 * (len - 2) + 4,
+        STORE(tmp + PARALLEL_COLS_53 * (len - 2) + VREG_INT_COUNT,
               ADD(d1n_1, SAR(ADD(s0n_1, tmp_len_minus_1), 1)));
 
     } else {
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0, ADD(d1n_0, s0n_0));
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, ADD(d1n_1, s0n_1));
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0,
+              ADD(d1n_0, s0n_0));
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT,
+              ADD(d1n_1, s0n_1));
     }
 
-    for (i = 0; i < len; ++i) {
-        memcpy(&tiledp_col[i * stride],
-               &tmp[PARALLEL_COLS_53 * i],
-               PARALLEL_COLS_53 * sizeof(OPJ_INT32));
-    }
+    opj_idwt53_v_final_memcpy(tiledp_col, tmp, len, stride);
 }
 
-/** Vertical inverse 5x3 wavelet transform for 8 columns, when top-most
- * pixel is on odd coordinate */
-static void opj_idwt53_v_cas1_8cols_SSE2(
+/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
+ * 16 in AVX2, when top-most pixel is on odd coordinate */
+static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(
     OPJ_INT32* tmp,
     const OPJ_INT32 sn,
     const OPJ_INT32 len,
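For reference, the arithmetic that the cas0 function above vectorizes across
PARALLEL_COLS_53 columns is the standard 5/3 inverse lifting: even (low-pass)
samples undo the update step, odd samples undo the prediction step. Note the
trick visible in the comments: (d + 1) >> 1 is rewritten as (d + d + 2) >> 2
so that both steps share the same ADD3/SAR sequence. A scalar single-column
sketch with the usual boundary clamping (a hypothetical helper, not part of
openjpeg; even has ceil(len/2) samples, odd has floor(len/2)):

    #include <assert.h>
    #include <stdint.h>

    static void idwt53_cas0_scalar(const int32_t* even, const int32_t* odd,
                                   int32_t* out, int32_t len)
    {
        int32_t i, j;
        assert(len > 1);
        /* Undo update: s[2j] = even[j] - ((d[j-1] + d[j] + 2) >> 2),
           with d clamped at both ends */
        for (i = 0, j = 0; i < len; i += 2, j++) {
            int32_t d_prev = (j > 0) ? odd[j - 1] : odd[0];
            int32_t d_next = (i + 1 < len) ? odd[j] : odd[j - 1];
            out[i] = even[j] - ((d_prev + d_next + 2) >> 2);
        }
        /* Undo predict: d[2j+1] = odd[j] + ((s[2j] + s[2j+2]) >> 1),
           with s clamped at the right end */
        for (i = 1, j = 0; i < len; i += 2, j++) {
            int32_t s_prev = out[i - 1];
            int32_t s_next = (i + 1 < len) ? out[i + 1] : out[i - 1];
            out[i] = odd[j] + ((s_prev + s_next) >> 1);
        }
    }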
@@ -664,15 +722,26 @@ static void opj_idwt53_v_cas1_8cols_SSE2(
 {
     OPJ_INT32 i, j;
-    __m128i s1_0, s2_0, dc_0, dn_0;
-    __m128i s1_1, s2_1, dc_1, dn_1;
-    const __m128i two = _mm_set1_epi32(2);
+    VREG s1_0, s2_0, dc_0, dn_0;
+    VREG s1_1, s2_1, dc_1, dn_1;
+    const VREG two = LOAD_CST(2);
     const OPJ_INT32* in_even = &tiledp_col[sn * stride];
     const OPJ_INT32* in_odd = &tiledp_col[0];
 
     assert(len > 2);
+#if __AVX2__
+    assert(PARALLEL_COLS_53 == 16);
+    assert(VREG_INT_COUNT == 8);
+#else
     assert(PARALLEL_COLS_53 == 8);
+    assert(VREG_INT_COUNT == 4);
+#endif
+
+    /* Note: loads of input even/odd values must be done in an unaligned */
+    /* fashion. But stores in tmp can be done with aligned store, since */
+    /* the temporary buffer is properly aligned */
+    assert((size_t)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
 
     s1_0 = LOADU(in_even + stride);
     /* in_odd[0] - ((in_even[0] + s1 + 2) >> 2); */
     dc_0 = SUB(LOADU(in_odd + 0),
@@ -680,30 +749,31 @@ static void opj_idwt53_v_cas1_8cols_SSE2(
                SAR(ADD3(LOADU(in_even + 0), s1_0, two), 2));
     STORE(tmp + PARALLEL_COLS_53 * 0, ADD(LOADU(in_even + 0), dc_0));
 
-    s1_1 = LOADU(in_even + stride + 4);
+    s1_1 = LOADU(in_even + stride + VREG_INT_COUNT);
     /* in_odd[0] - ((in_even[0] + s1 + 2) >> 2); */
-    dc_1 = SUB(LOADU(in_odd + 4),
-               SAR(ADD3(LOADU(in_even + 4), s1_1, two), 2));
-    STORE(tmp + PARALLEL_COLS_53 * 0 + 4, ADD(LOADU(in_even + 4), dc_1));
+    dc_1 = SUB(LOADU(in_odd + VREG_INT_COUNT),
+               SAR(ADD3(LOADU(in_even + VREG_INT_COUNT), s1_1, two), 2));
+    STORE(tmp + PARALLEL_COLS_53 * 0 + VREG_INT_COUNT,
+          ADD(LOADU(in_even + VREG_INT_COUNT), dc_1));
 
     for (i = 1, j = 1; i < (len - 2 - !(len & 1)); i += 2, j++) {
 
         s2_0 = LOADU(in_even + (j + 1) * stride);
-        s2_1 = LOADU(in_even + (j + 1) * stride + 4);
+        s2_1 = LOADU(in_even + (j + 1) * stride + VREG_INT_COUNT);
 
         /* dn = in_odd[j * stride] - ((s1 + s2 + 2) >> 2); */
         dn_0 = SUB(LOADU(in_odd + j * stride),
                    SAR(ADD3(s1_0, s2_0, two), 2));
-        dn_1 = SUB(LOADU(in_odd + j * stride + 4),
+        dn_1 = SUB(LOADU(in_odd + j * stride + VREG_INT_COUNT),
                    SAR(ADD3(s1_1, s2_1, two), 2));
 
         STORE(tmp + PARALLEL_COLS_53 * i, dc_0);
-        STORE(tmp + PARALLEL_COLS_53 * i + 4, dc_1);
+        STORE(tmp + PARALLEL_COLS_53 * i + VREG_INT_COUNT, dc_1);
 
         /* tmp[i + 1] = s1 + ((dn + dc) >> 1); */
         STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 0,
               ADD(s1_0, SAR(ADD(dn_0, dc_0), 1)));
-        STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 4,
+        STORE(tmp + PARALLEL_COLS_53 * (i + 1) + VREG_INT_COUNT,
               ADD(s1_1, SAR(ADD(dn_1, dc_1), 1)));
 
         dc_0 = dn_0;
@@ -712,43 +782,44 @@ static void opj_idwt53_v_cas1_8cols_SSE2(
         s1_1 = s2_1;
     }
     STORE(tmp + PARALLEL_COLS_53 * i, dc_0);
-    STORE(tmp + PARALLEL_COLS_53 * i + 4, dc_1);
+    STORE(tmp + PARALLEL_COLS_53 * i + VREG_INT_COUNT, dc_1);
 
     if (!(len & 1)) {
         /*dn = in_odd[(len / 2 - 1) * stride] - ((s1 + 1) >> 1); */
         dn_0 = SUB(LOADU(in_odd + (len / 2 - 1) * stride),
                    SAR(ADD3(s1_0, s1_0, two), 2));
-        dn_1 = SUB(LOADU(in_odd + (len / 2 - 1) * stride + 4),
+        dn_1 = SUB(LOADU(in_odd + (len / 2 - 1) * stride + VREG_INT_COUNT),
                    SAR(ADD3(s1_1, s1_1, two), 2));
 
         /* tmp[len - 2] = s1 + ((dn + dc) >> 1); */
         STORE(tmp + PARALLEL_COLS_53 * (len - 2) + 0,
               ADD(s1_0, SAR(ADD(dn_0, dc_0), 1)));
-        STORE(tmp + PARALLEL_COLS_53 * (len - 2) + 4,
+        STORE(tmp + PARALLEL_COLS_53 * (len - 2) + VREG_INT_COUNT,
               ADD(s1_1, SAR(ADD(dn_1, dc_1), 1)));
 
         STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0, dn_0);
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, dn_1);
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT, dn_1);
     } else {
         STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0, ADD(s1_0, dc_0));
-        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, ADD(s1_1, dc_1));
+        STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT,
+              ADD(s1_1, dc_1));
     }
 
-    for (i = 0; i < len; ++i) {
-        memcpy(&tiledp_col[i * stride],
-               &tmp[PARALLEL_COLS_53 * i],
-               PARALLEL_COLS_53 * sizeof(OPJ_INT32));
-    }
+    opj_idwt53_v_final_memcpy(tiledp_col, tmp, len, stride);
 }
 
+#undef VREG
+#undef LOAD_CST
 #undef LOADU
+#undef LOAD
 #undef STORE
+#undef STOREU
 #undef ADD
 #undef ADD3
 #undef SUB
 #undef SAR
 
-#endif /* defined(__SSE2__) && !defined(STANDARD_SLOW_VERSION) */
+#endif /* (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION) */
 
 #if !defined(STANDARD_SLOW_VERSION)
 /** Vertical inverse 5x3 wavelet transform for one column, when top-most
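Both vertical functions now funnel through opj_idwt53_v_final_memcpy. The tmp
buffer interleaves the PARALLEL_COLS_53 columns row by row, so row i of column c
sits at tmp[PARALLEL_COLS_53 * i + c]: the aligned LOAD is safe because tmp comes
from the 32-byte-aligned allocation introduced below, while the destination rows
are only stride apart and therefore take the unaligned STOREU. In scalar form the
helper is simply (illustrative sketch, not part of the patch; int32_t stands in
for OPJ_INT32):

    #include <stdint.h>

    static void final_copy_scalar(int32_t* tiledp_col, const int32_t* tmp,
                                  int32_t len, int32_t stride,
                                  int32_t parallel_cols)
    {
        int32_t i, c;
        for (i = 0; i < len; ++i) {
            /* row i of column c lives at tmp[parallel_cols * i + c] */
            for (c = 0; c < parallel_cols; ++c) {
                tiledp_col[i * stride + c] = tmp[parallel_cols * i + c];
            }
        }
    }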
@@ -873,11 +944,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
 
     if (dwt->cas == 0) {
         /* If len == 1, unmodified value */
-#if __SSE2__
+#if (defined(__SSE2__) || defined(__AVX2__))
         if (len > 1 && nb_cols == PARALLEL_COLS_53) {
-            /* Same as below general case, except that thanks to SSE2 */
-            /* we can efficently process 8 columns in parallel */
-            opj_idwt53_v_cas0_8cols_SSE2(dwt->mem, sn, len, tiledp_col, stride);
+            /* Same as below general case, except that thanks to SSE2/AVX2 */
+            /* we can efficiently process 8/16 columns in parallel */
+            opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
             return;
         }
 #endif
@@ -916,11 +987,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
         return;
     }
 
-#ifdef __SSE2__
+#if (defined(__SSE2__) || defined(__AVX2__))
     if (len > 2 && nb_cols == PARALLEL_COLS_53) {
-        /* Same as below general case, except that thanks to SSE2 */
-        /* we can efficently process 8 columns in parallel */
-        opj_idwt53_v_cas1_8cols_SSE2(dwt->mem, sn, len, tiledp_col, stride);
+        /* Same as below general case, except that thanks to SSE2/AVX2 */
+        /* we can efficiently process 8/16 columns in parallel */
+        opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
         return;
     }
 #endif
@@ -1291,7 +1362,7 @@ static OPJ_BOOL opj_dwt_decode_tile(opj_thread_pool_t* tp,
     /* since for the vertical pass */
     /* we process PARALLEL_COLS_53 columns at a time */
     h_mem_size *= PARALLEL_COLS_53 * sizeof(OPJ_INT32);
-    h.mem = (OPJ_INT32*)opj_aligned_malloc(h_mem_size);
+    h.mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
     if (! h.mem) {
         /* FIXME event manager error callback */
         return OPJ_FALSE;
@@ -1348,7 +1419,7 @@ static OPJ_BOOL opj_dwt_decode_tile(opj_thread_pool_t* tp,
             if (j == (num_jobs - 1U)) {  /* this will take care of the overflow */
                 job->max_j = rh;
             }
-            job->h.mem = (OPJ_INT32*)opj_aligned_malloc(h_mem_size);
+            job->h.mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
             if (!job->h.mem) {
                 /* FIXME event manager error callback */
                 opj_thread_pool_wait_completion(tp, 0);
@@ -1403,7 +1474,7 @@ static OPJ_BOOL opj_dwt_decode_tile(opj_thread_pool_t* tp,
             if (j == (num_jobs - 1U)) {  /* this will take care of the overflow */
                 job->max_j = rw;
             }
-            job->v.mem = (OPJ_INT32*)opj_aligned_malloc(h_mem_size);
+            job->v.mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
             if (!job->v.mem) {
                 /* FIXME event manager error callback */
                 opj_thread_pool_wait_completion(tp, 0);
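The allocator changes above and below exist because the aligned AVX2 accesses
behind LOAD()/STORE() (_mm256_load_si256/_mm256_store_si256) require 32-byte
alignment; the 16-byte alignment that opj_aligned_malloc guarantees is enough
for SSE2 but not for AVX2. The new helpers simply delegate to the existing
opj_aligned_alloc_n/opj_aligned_realloc_n with an alignment of 32. As a rough
sketch of what such an allocation amounts to in portable C11 (hypothetical
stand-in; the real opj_aligned_alloc_n picks a platform primitive, as the
OPJ_HAVE_POSIX_MEMALIGN/OPJ_HAVE_MEMALIGN test in opj_aligned_free suggests):

    #include <stdlib.h>

    static void* aligned_32_malloc_sketch(size_t size)
    {
        /* C11 aligned_alloc requires size to be a multiple of the
           alignment, hence the round-up to the next multiple of 32 */
        size_t padded = (size + 31U) & ~(size_t)31U;
        return aligned_alloc(32U, padded);
    }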
diff --git a/src/lib/openjp2/opj_malloc.c b/src/lib/openjp2/opj_malloc.c
index 9c438bdb..dca91bfc 100644
--- a/src/lib/openjp2/opj_malloc.c
+++ b/src/lib/openjp2/opj_malloc.c
@@ -213,6 +213,15 @@ void * opj_aligned_realloc(void *ptr, size_t size)
     return opj_aligned_realloc_n(ptr, 16U, size);
 }
 
+void *opj_aligned_32_malloc(size_t size)
+{
+    return opj_aligned_alloc_n(32U, size);
+}
+
+void * opj_aligned_32_realloc(void *ptr, size_t size)
+{
+    return opj_aligned_realloc_n(ptr, 32U, size);
+}
+
 void opj_aligned_free(void* ptr)
 {
 #if defined(OPJ_HAVE_POSIX_MEMALIGN) || defined(OPJ_HAVE_MEMALIGN)
diff --git a/src/lib/openjp2/opj_malloc.h b/src/lib/openjp2/opj_malloc.h
index c8c2fc2d..7503c28d 100644
--- a/src/lib/openjp2/opj_malloc.h
+++ b/src/lib/openjp2/opj_malloc.h
@@ -71,6 +71,14 @@ void * opj_aligned_malloc(size_t size);
 void * opj_aligned_realloc(void *ptr, size_t size);
 void opj_aligned_free(void* ptr);
 
+/**
+Allocate memory aligned to a 32-byte boundary
+@param size Bytes to allocate
+@return Returns a void pointer to the allocated space, or NULL if there is insufficient memory available
+*/
+void * opj_aligned_32_malloc(size_t size);
+void * opj_aligned_32_realloc(void *ptr, size_t size);
+
 /**
 Reallocate memory blocks.
 @param m Pointer to previously allocated memory block
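Note that none of this adds runtime dispatch: __AVX2__ is a compile-time test, so
the AVX2 path is only built when the compiler targets it (for instance -mavx2
with GCC/Clang or /arch:AVX2 with MSVC, both of which define __AVX2__). The new
allocator is used exactly like the 16-byte variant; a hypothetical usage
mirroring the dwt.c changes, with h_mem_size standing for the temporary-buffer
size computed in opj_dwt_decode_tile:

    OPJ_INT32* mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
    if (mem == NULL) {
        return OPJ_FALSE; /* allocation failure */
    }
    /* ... vertical pass over PARALLEL_COLS_53 columns using mem ... */
    opj_aligned_free(mem); /* released as for the 16-byte variant */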