diff --git a/Makefile b/Makefile
index f6126bf..76bf162 100644
--- a/Makefile
+++ b/Makefile
@@ -77,6 +77,10 @@ ifeq ($(strip $(__ISA_AVX__)),true)
     DEFINES += -D__ISA_AVX__
 endif
 
+ifeq ($(strip $(__ISA_AVX_FMA__)),true)
+    DEFINES += -D__ISA_AVX_FMA__
+endif
+
 ifeq ($(strip $(__ISA_AVX2__)),true)
     DEFINES += -D__ISA_AVX2__
 endif
diff --git a/common/includes/simd/avx_double.h b/common/includes/simd/avx_double.h
index da7f6bf..40774e1 100644
--- a/common/includes/simd/avx_double.h
+++ b/common/includes/simd/avx_double.h
@@ -61,7 +61,11 @@ static inline MD_FLOAT simd_incr_reduced_sum(MD_FLOAT *m, MD_SIMD_FLOAT v0, MD_S
 
 static inline MD_SIMD_FLOAT select_by_mask(MD_SIMD_FLOAT a, MD_SIMD_MASK m) { return _mm256_and_pd(a, m); }
 static inline MD_SIMD_FLOAT simd_reciprocal(MD_SIMD_FLOAT a) { return _mm256_cvtps_pd(_mm_rcp_ps(_mm256_cvtpd_ps(a))); }
+#ifdef __ISA_AVX_FMA__
+static inline MD_SIMD_FLOAT simd_fma(MD_SIMD_FLOAT a, MD_SIMD_FLOAT b, MD_SIMD_FLOAT c) { return _mm256_fmadd_pd(a, b, c); }
+#else
 static inline MD_SIMD_FLOAT simd_fma(MD_SIMD_FLOAT a, MD_SIMD_FLOAT b, MD_SIMD_FLOAT c) { return simd_add(simd_mul(a, b), c); }
+#endif
 static inline MD_SIMD_FLOAT simd_masked_add(MD_SIMD_FLOAT a, MD_SIMD_FLOAT b, MD_SIMD_MASK m) { return simd_add(a, _mm256_and_pd(b, m)); }
 static inline MD_SIMD_MASK simd_mask_cond_lt(MD_SIMD_FLOAT a, MD_SIMD_FLOAT b) { return _mm256_cmp_pd(a, b, _CMP_LT_OQ); }
 static inline MD_SIMD_MASK simd_mask_int_cond_lt(MD_SIMD_INT a, MD_SIMD_INT b) { return _mm256_cvtepi32_pd(_mm_cmplt_epi32(a, b)); }
diff --git a/config.mk b/config.mk
index 696f1e5..677b142 100644
--- a/config.mk
+++ b/config.mk
@@ -1,6 +1,6 @@
 # Compiler tag (GCC/CLANG/ICC/ONEAPI/NVCC)
 TAG ?= ICC
-# Instruction set (SSE/AVX/AVX2/AVX512)
+# Instruction set (SSE/AVX/AVX_FMA/AVX2/AVX512)
 ISA ?= AVX512
 # Optimization scheme (lammps/gromacs/clusters_per_bin)
 OPT_SCHEME ?= lammps
diff --git a/include_ISA.mk b/include_ISA.mk
index 1866e49..a8c8e24 100644
--- a/include_ISA.mk
+++ b/include_ISA.mk
@@ -4,6 +4,10 @@ ifeq ($(strip $(ISA)), SSE)
 else ifeq ($(strip $(ISA)), AVX)
     __ISA_AVX__=true
     __SIMD_WIDTH_DBL__=4
+else ifeq ($(strip $(ISA)), AVX_FMA)
+    __ISA_AVX__=true
+    __ISA_AVX_FMA__=true
+    __SIMD_WIDTH_DBL__=4
 else ifeq ($(strip $(ISA)), AVX2)
     __ISA_AVX2__=true
     #__SIMD_KERNEL__=true