diff --git a/src/neighbor.cu b/src/neighbor.cu
index d4b1ccd..fa6e947 100644
--- a/src/neighbor.cu
+++ b/src/neighbor.cu
@@ -70,6 +70,44 @@ __device__ int coord2bin_device(MD_FLOAT xin, MD_FLOAT yin, MD_FLOAT zin,
     return (iz * np.mbiny * np.mbinx + iy * np.mbinx + ix + 1);
 }
 
+/*
+ * Sorts the entries stored in each bin into ascending order so that the bin
+ * contents are comparable to the CPU version.
+ *
+ * Launch layout: one thread per bin, indexed with the x dimension only;
+ * threads whose global index is >= mbins return immediately.
+ *
+ * bincount      - bincount[i] is read as the number of entries in bin i
+ * bins          - bin i starts at bins[i * atoms_per_bin]; sorted in place
+ * mbins         - number of bins to process
+ * atoms_per_bin - stride (in ints) between consecutive bins in `bins`
+ *
+ * Uses bubble sort since atoms per bin should be relatively small and the
+ * sort can be done in situ.
+ */
+__global__ void sort_bin_contents_kernel(int* bincount, int* bins, int mbins, int atoms_per_bin){
+    const int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= mbins){
+        return;
+    }
+
+    int atoms_in_bin = bincount[i];
+    int* bin_ptr = &bins[i * atoms_per_bin];
+    int sorted;
+    do {
+        sorted = 1;
+        int tmp;
+        for(int index = 0; index < atoms_in_bin - 1; index++){
+            if (bin_ptr[index] > bin_ptr[index + 1]){
+                tmp = bin_ptr[index];
+                bin_ptr[index] = bin_ptr[index + 1];
+                bin_ptr[index + 1] = tmp;
+                sorted = 0;
+            }
+        }
+    } while (!sorted);
+}
+
 __global__ void binatoms_kernel(Atom a, int* bincount, int* bins, int atoms_per_bin, Neighbor_params np, int *resize_needed){
     Atom* atom = &a;
     const int i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -571,6 +609,12 @@ void binatoms_cuda(Atom* c_atom, Binning* c_binning, int* c_resize_needed, Neigh
             checkCUDAError("binatoms_cuda c_binning->bins resize malloc", cudaMalloc(&c_binning->bins, c_binning->mbins * c_binning->atoms_per_bin * sizeof(int)) );
         }
     }
+    atoms_per_bin = c_binning->atoms_per_bin;
+    const int sortBlocks = ceil((float)mbins / (float)threads_per_block);
+    /* sort the contents of every bin in place; launched with one thread per bin */
+    sort_bin_contents_kernel<<<sortBlocks, threads_per_block>>>(c_binning->bincount, c_binning->bins, c_binning->mbins, c_binning->atoms_per_bin);
+    checkCUDAError( "PeekAtLastError sort_bin_contents kernel", cudaPeekAtLastError() );
+    checkCUDAError( "DeviceSync sort_bin_contents kernel", cudaDeviceSynchronize() );
 }
 
 void buildNeighbor_cuda(Atom *atom, Neighbor
*neighbor, Atom *c_atom, Neighbor *c_neighbor, const int num_threads_per_block)