diff --git a/gromacs/includes/simd/avx512_double.h b/gromacs/includes/simd/avx512_double.h index 18d67eb..3777c0c 100644 --- a/gromacs/includes/simd/avx512_double.h +++ b/gromacs/includes/simd/avx512_double.h @@ -40,6 +40,7 @@ static inline MD_SIMD_MASK simd_mask_cond_lt(MD_SIMD_FLOAT a, MD_SIMD_FLOAT b) { static inline MD_SIMD_MASK simd_mask_from_u32(unsigned int a) { return _cvtu32_mask8(a); } static inline unsigned int simd_mask_to_u32(MD_SIMD_MASK a) { return _cvtmask8_u32(a); } static inline MD_SIMD_FLOAT simd_load(MD_FLOAT *p) { return _mm512_load_pd(p); } +static inline MD_SIMD_FLOAT select_by_mask(MD_SIMD_FLOAT a, MD_SIMD_MASK m) { return _mm512_mask_mov_pd(_mm512_setzero_pd(), m, a); } static inline MD_FLOAT simd_h_reduce_sum(MD_SIMD_FLOAT a) { MD_SIMD_FLOAT x = _mm512_add_pd(a, _mm512_shuffle_f64x2(a, a, 0xee)); x = _mm512_add_pd(x, _mm512_shuffle_f64x2(x, x, 0x11)); @@ -94,3 +95,17 @@ static inline MD_FLOAT simd_h_dual_incr_reduced_sum(MD_FLOAT *m, MD_SIMD_FLOAT v t0 = _mm512_add_pd(t0, _mm512_permutex_pd(t0, 0xb1)); return _mm_cvtsd_f64(_mm512_castpd512_pd128(t0)); } + +inline void simd_h_decr(MD_FLOAT *m, MD_SIMD_FLOAT a) { + __m256d t; + a = _mm512_add_pd(a, _mm512_shuffle_f64x2(a, a, 0xee)); + t = _mm256_load_pd(m); + t = _mm256_sub_pd(t, _mm512_castpd512_pd256(a)); + _mm256_store_pd(m, t); +} + +static inline void simd_h_decr3(MD_FLOAT *m, MD_SIMD_FLOAT a0, MD_SIMD_FLOAT a1, MD_SIMD_FLOAT a2) { + simd_h_decr(m, a0); + simd_h_decr(m + CLUSTER_N, a1); + simd_h_decr(m + CLUSTER_N * 2, a2); +} diff --git a/gromacs/includes/simd/avx512_float.h b/gromacs/includes/simd/avx512_float.h index 104ad2c..1f7c0ee 100644 --- a/gromacs/includes/simd/avx512_float.h +++ b/gromacs/includes/simd/avx512_float.h @@ -95,6 +95,6 @@ inline void simd_h_decr(MD_FLOAT *m, MD_SIMD_FLOAT a) { static inline void simd_h_decr3(MD_FLOAT *m, MD_SIMD_FLOAT a0, MD_SIMD_FLOAT a1, MD_SIMD_FLOAT a2) { simd_h_decr(m, a0); - simd_h_decr(m + CLUSTER_M, a1); - simd_h_decr(m + CLUSTER_M * 2, a2); + simd_h_decr(m + CLUSTER_N, a1); + simd_h_decr(m + CLUSTER_N * 2, a2); }