Pre-compute masks for 4xn kernels

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
This commit is contained in:
Rafael Ravedutti 2023-03-28 22:30:30 +02:00
parent 04ade6bcec
commit 5c000444a4
3 changed files with 58 additions and 26 deletions

View File

@ -431,6 +431,16 @@ void initMasks(Atom *atom) {
mask3 = (unsigned int)(0xf - 0x8 * cond0); mask3 = (unsigned int)(0xf - 0x8 * cond0);
atom->masks_2xnn_fn[cond0 * 2 + 0] = (mask1 << half_mask_bits) | mask0; atom->masks_2xnn_fn[cond0 * 2 + 0] = (mask1 << half_mask_bits) | mask0;
atom->masks_2xnn_fn[cond0 * 2 + 1] = (mask3 << half_mask_bits) | mask2; atom->masks_2xnn_fn[cond0 * 2 + 1] = (mask3 << half_mask_bits) | mask2;
atom->masks_4xn_hn[cond0 * 4 + 0] = (unsigned int)(0xf - 0x1 * cond0);
atom->masks_4xn_hn[cond0 * 4 + 1] = (unsigned int)(0xf - 0x3 * cond0);
atom->masks_4xn_hn[cond0 * 4 + 2] = (unsigned int)(0xf - 0x7 * cond0);
atom->masks_4xn_hn[cond0 * 4 + 3] = (unsigned int)(0xf - 0xf * cond0);
atom->masks_4xn_fn[cond0 * 4 + 0] = (unsigned int)(0xf - 0x1 * cond0);
atom->masks_4xn_fn[cond0 * 4 + 1] = (unsigned int)(0xf - 0x2 * cond0);
atom->masks_4xn_fn[cond0 * 4 + 2] = (unsigned int)(0xf - 0x4 * cond0);
atom->masks_4xn_fn[cond0 * 4 + 3] = (unsigned int)(0xf - 0x8 * cond0);
} }
#else #else
for(unsigned int cond0 = 0; cond0 < 2; cond0++) { for(unsigned int cond0 = 0; cond0 < 2; cond0++) {
@ -464,6 +474,28 @@ void initMasks(Atom *atom) {
atom->masks_2xnn_fn[cond0 * 4 + cond1 * 2 + 0] = (mask1 << half_mask_bits) | mask0; atom->masks_2xnn_fn[cond0 * 4 + cond1 * 2 + 0] = (mask1 << half_mask_bits) | mask0;
atom->masks_2xnn_fn[cond0 * 4 + cond1 * 2 + 1] = (mask3 << half_mask_bits) | mask2; atom->masks_2xnn_fn[cond0 * 4 + cond1 * 2 + 1] = (mask3 << half_mask_bits) | mask2;
#if CLUSTER_M < CLUSTER_N
atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 0] = (unsigned int)(0xff - 0x1 * cond0 - 0x1f * cond1);
atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 1] = (unsigned int)(0xff - 0x3 * cond0 - 0x3f * cond1);
atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 2] = (unsigned int)(0xff - 0x7 * cond0 - 0x7f * cond1);
atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 3] = (unsigned int)(0xff - 0xf * cond0 - 0xff * cond1);
atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 0] = (unsigned int)(0xff - 0x1 * cond0 - 0x10 * cond1);
atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 1] = (unsigned int)(0xff - 0x2 * cond0 - 0x20 * cond1);
atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 2] = (unsigned int)(0xff - 0x4 * cond0 - 0x40 * cond1);
atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 3] = (unsigned int)(0xff - 0x8 * cond0 - 0x80 * cond1);
#else
atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 0] = (unsigned int)(0x3 - 0x1 * cond0);
atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 1] = (unsigned int)(0x3 - 0x3 * cond0);
atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 2] = (unsigned int)(0x3 - 0x3 * cond0 - 0x1 * cond1);
atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 3] = (unsigned int)(0x3 - 0x3 * cond0 - 0x3 * cond1);
/* One distinct slot per mask: computeForceLJ_4xn_full reads slots +0..+3 of
 * masks_4xn_fn as excl_mask0..excl_mask3, so each store needs its own index. */
atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 0] = (unsigned int)(0x3 - 0x1 * cond0);
atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 1] = (unsigned int)(0x3 - 0x2 * cond0);
atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 2] = (unsigned int)(0x3 - 0x1 * cond1);
atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 3] = (unsigned int)(0x3 - 0x2 * cond1);
#endif
} }
} }
#endif #endif

View File

@ -16,6 +16,7 @@
#include <simd.h> #include <simd.h>
/*
static inline void gmx_load_simd_2xnn_interactions( static inline void gmx_load_simd_2xnn_interactions(
int excl, int excl,
MD_SIMD_BITMASK filter0, MD_SIMD_BITMASK filter2, MD_SIMD_BITMASK filter0, MD_SIMD_BITMASK filter2,
@ -39,6 +40,7 @@ static inline void gmx_load_simd_4xn_interactions(
*interact2 = cvtIB2B(simd_test_bits(mask_pr_S & filter2)); *interact2 = cvtIB2B(simd_test_bits(mask_pr_S & filter2));
*interact3 = cvtIB2B(simd_test_bits(mask_pr_S & filter3)); *interact3 = cvtIB2B(simd_test_bits(mask_pr_S & filter3));
} }
*/
double computeForceLJ_ref(Parameter *param, Atom *atom, Neighbor *neighbor, Stats *stats) { double computeForceLJ_ref(Parameter *param, Atom *atom, Neighbor *neighbor, Stats *stats) {
DEBUG_MESSAGE("computeForceLJ begin\n"); DEBUG_MESSAGE("computeForceLJ begin\n");
@ -655,24 +657,22 @@ double computeForceLJ_4xn_half(Parameter *param, Atom *atom, Neighbor *neighbor,
#if CLUSTER_M == CLUSTER_N #if CLUSTER_M == CLUSTER_N
unsigned int cond0 = (unsigned int)(cj == ci_cj0); unsigned int cond0 = (unsigned int)(cj == ci_cj0);
MD_SIMD_MASK excl_mask0 = simd_mask_from_u32((unsigned int)(0xf - 0x1 * cond0)); MD_SIMD_MASK excl_mask0 = simd_mask_from_u32(atom->masks_4xn_hn[cond0 * 4 + 0]);
MD_SIMD_MASK excl_mask1 = simd_mask_from_u32((unsigned int)(0xf - 0x3 * cond0)); MD_SIMD_MASK excl_mask1 = simd_mask_from_u32(atom->masks_4xn_hn[cond0 * 4 + 1]);
MD_SIMD_MASK excl_mask2 = simd_mask_from_u32((unsigned int)(0xf - 0x7 * cond0)); MD_SIMD_MASK excl_mask2 = simd_mask_from_u32(atom->masks_4xn_hn[cond0 * 4 + 2]);
MD_SIMD_MASK excl_mask3 = simd_mask_from_u32((unsigned int)(0xf - 0xf * cond0)); MD_SIMD_MASK excl_mask3 = simd_mask_from_u32(atom->masks_4xn_hn[cond0 * 4 + 3]);
#elif CLUSTER_M < CLUSTER_N #else
#if CLUSTER_M < CLUSTER_N
unsigned int cond0 = (unsigned int)((cj << 1) + 0 == ci); unsigned int cond0 = (unsigned int)((cj << 1) + 0 == ci);
unsigned int cond1 = (unsigned int)((cj << 1) + 1 == ci); unsigned int cond1 = (unsigned int)((cj << 1) + 1 == ci);
MD_SIMD_MASK excl_mask0 = simd_mask_from_u32((unsigned int)(0xff - 0x1 * cond0 - 0x1f * cond1));
MD_SIMD_MASK excl_mask1 = simd_mask_from_u32((unsigned int)(0xff - 0x3 * cond0 - 0x3f * cond1));
MD_SIMD_MASK excl_mask2 = simd_mask_from_u32((unsigned int)(0xff - 0x7 * cond0 - 0x7f * cond1));
MD_SIMD_MASK excl_mask3 = simd_mask_from_u32((unsigned int)(0xff - 0xf * cond0 - 0xff * cond1));
#else #else
unsigned int cond0 = (unsigned int)(cj == ci_cj0); unsigned int cond0 = (unsigned int)(cj == ci_cj0);
unsigned int cond1 = (unsigned int)(cj == ci_cj1); unsigned int cond1 = (unsigned int)(cj == ci_cj1);
MD_SIMD_MASK excl_mask0 = simd_mask_from_u32((unsigned int)(0x3 - 0x1 * cond0)); #endif
MD_SIMD_MASK excl_mask1 = simd_mask_from_u32((unsigned int)(0x3 - 0x3 * cond0)); MD_SIMD_MASK excl_mask0 = simd_mask_from_u32(atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 0]);
MD_SIMD_MASK excl_mask2 = simd_mask_from_u32((unsigned int)(0x3 - 0x3 * cond0 - 0x1 * cond1)); MD_SIMD_MASK excl_mask1 = simd_mask_from_u32(atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 1]);
MD_SIMD_MASK excl_mask3 = simd_mask_from_u32((unsigned int)(0x3 - 0x3 * cond0 - 0x3 * cond1)); MD_SIMD_MASK excl_mask2 = simd_mask_from_u32(atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 2]);
MD_SIMD_MASK excl_mask3 = simd_mask_from_u32(atom->masks_4xn_hn[cond0 * 8 + cond1 * 4 + 3]);
#endif #endif
MD_SIMD_FLOAT rsq0 = simd_fma(delx0, delx0, simd_fma(dely0, dely0, simd_mul(delz0, delz0))); MD_SIMD_FLOAT rsq0 = simd_fma(delx0, delx0, simd_fma(dely0, dely0, simd_mul(delz0, delz0)));
@ -845,24 +845,22 @@ double computeForceLJ_4xn_full(Parameter *param, Atom *atom, Neighbor *neighbor,
#if CLUSTER_M == CLUSTER_N #if CLUSTER_M == CLUSTER_N
unsigned int cond0 = (unsigned int)(cj == ci_cj0); unsigned int cond0 = (unsigned int)(cj == ci_cj0);
MD_SIMD_MASK excl_mask0 = simd_mask_from_u32((unsigned int)(0xf - 0x1 * cond0)); MD_SIMD_MASK excl_mask0 = simd_mask_from_u32(atom->masks_4xn_fn[cond0 * 4 + 0]);
MD_SIMD_MASK excl_mask1 = simd_mask_from_u32((unsigned int)(0xf - 0x2 * cond0)); MD_SIMD_MASK excl_mask1 = simd_mask_from_u32(atom->masks_4xn_fn[cond0 * 4 + 1]);
MD_SIMD_MASK excl_mask2 = simd_mask_from_u32((unsigned int)(0xf - 0x4 * cond0)); MD_SIMD_MASK excl_mask2 = simd_mask_from_u32(atom->masks_4xn_fn[cond0 * 4 + 2]);
MD_SIMD_MASK excl_mask3 = simd_mask_from_u32((unsigned int)(0xf - 0x8 * cond0)); MD_SIMD_MASK excl_mask3 = simd_mask_from_u32(atom->masks_4xn_fn[cond0 * 4 + 3]);
#elif CLUSTER_M < CLUSTER_N #else
#if CLUSTER_M < CLUSTER_N
unsigned int cond0 = (unsigned int)((cj << 1) + 0 == ci); unsigned int cond0 = (unsigned int)((cj << 1) + 0 == ci);
unsigned int cond1 = (unsigned int)((cj << 1) + 1 == ci); unsigned int cond1 = (unsigned int)((cj << 1) + 1 == ci);
MD_SIMD_MASK excl_mask0 = simd_mask_from_u32((unsigned int)(0xff - 0x1 * cond0 - 0x10 * cond1));
MD_SIMD_MASK excl_mask1 = simd_mask_from_u32((unsigned int)(0xff - 0x2 * cond0 - 0x20 * cond1));
MD_SIMD_MASK excl_mask2 = simd_mask_from_u32((unsigned int)(0xff - 0x4 * cond0 - 0x40 * cond1));
MD_SIMD_MASK excl_mask3 = simd_mask_from_u32((unsigned int)(0xff - 0x8 * cond0 - 0x80 * cond1));
#else #else
unsigned int cond0 = (unsigned int)(cj == ci_cj0); unsigned int cond0 = (unsigned int)(cj == ci_cj0);
unsigned int cond1 = (unsigned int)(cj == ci_cj1); unsigned int cond1 = (unsigned int)(cj == ci_cj1);
MD_SIMD_MASK excl_mask0 = simd_mask_from_u32((unsigned int)(0x3 - 0x1 * cond0)); #endif
MD_SIMD_MASK excl_mask1 = simd_mask_from_u32((unsigned int)(0x3 - 0x2 * cond0)); MD_SIMD_MASK excl_mask0 = simd_mask_from_u32(atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 0]);
MD_SIMD_MASK excl_mask2 = simd_mask_from_u32((unsigned int)(0x3 - 0x1 * cond1)); MD_SIMD_MASK excl_mask1 = simd_mask_from_u32(atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 1]);
MD_SIMD_MASK excl_mask3 = simd_mask_from_u32((unsigned int)(0x3 - 0x2 * cond1)); MD_SIMD_MASK excl_mask2 = simd_mask_from_u32(atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 2]);
MD_SIMD_MASK excl_mask3 = simd_mask_from_u32(atom->masks_4xn_fn[cond0 * 8 + cond1 * 4 + 3]);
#endif #endif
MD_SIMD_FLOAT rsq0 = simd_fma(delx0, delx0, simd_fma(dely0, dely0, simd_mul(delz0, delz0))); MD_SIMD_FLOAT rsq0 = simd_fma(delx0, delx0, simd_fma(dely0, dely0, simd_mul(delz0, delz0)));

View File

@ -126,6 +126,8 @@ typedef struct {
MD_FLOAT *diagonal_2xnn_j_minus_i; MD_FLOAT *diagonal_2xnn_j_minus_i;
unsigned int masks_2xnn_hn[8]; unsigned int masks_2xnn_hn[8];
unsigned int masks_2xnn_fn[8]; unsigned int masks_2xnn_fn[8];
unsigned int masks_4xn_hn[16];
unsigned int masks_4xn_fn[16];
} Atom; } Atom;
extern void initAtom(Atom*); extern void initAtom(Atom*);