Skip to content

Commit ce9baec

Browse files
committed
qnep neighbor skin
1 parent 96a591c commit ce9baec

File tree

2 files changed

+27
-108
lines changed

2 files changed

+27
-108
lines changed

src/force/nep_charge.cu

Lines changed: 25 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -352,11 +352,9 @@ NEP_Charge::NEP_Charge(const char* file_potential, const int num_atoms)
352352
nep_data.Fp.resize(num_atoms * annmb.dim);
353353
nep_data.sum_fxyz.resize(
354354
num_atoms * (paramb.n_max_angular + 1) * ((paramb.L_max + 1) * (paramb.L_max + 1) - 1));
355-
nep_data.cell_count.resize(num_atoms);
356-
nep_data.cell_count_sum.resize(num_atoms);
357-
nep_data.cell_contents.resize(num_atoms);
358355
nep_data.cpu_NN_radial.resize(num_atoms);
359356
nep_data.cpu_NN_angular.resize(num_atoms);
357+
neighbor.initialize(rc, num_atoms, paramb.MN_radial);
360358

361359
initialize_dftd3();
362360
}
@@ -391,17 +389,13 @@ static __global__ void find_neighbor_list_large_box(
391389
const int N,
392390
const int N1,
393391
const int N2,
394-
const int nx,
395-
const int ny,
396-
const int nz,
397392
const Box box,
398393
const int* g_type,
399-
const int* __restrict__ g_cell_count,
400-
const int* __restrict__ g_cell_count_sum,
401-
const int* __restrict__ g_cell_contents,
402394
const double* __restrict__ g_x,
403395
const double* __restrict__ g_y,
404396
const double* __restrict__ g_z,
397+
const int* __restrict__ g_NN_global,
398+
const int* __restrict__ g_NL_global,
405399
int* g_NN_radial,
406400
int* g_NL_radial,
407401
int* g_NN_angular,
@@ -418,75 +412,21 @@ static __global__ void find_neighbor_list_large_box(
418412
int count_radial = 0;
419413
int count_angular = 0;
420414

421-
int cell_id;
422-
int cell_id_x;
423-
int cell_id_y;
424-
int cell_id_z;
425-
find_cell_id(
426-
box,
427-
x1,
428-
y1,
429-
z1,
430-
2.0f * paramb.rcinv_radial,
431-
nx,
432-
ny,
433-
nz,
434-
cell_id_x,
435-
cell_id_y,
436-
cell_id_z,
437-
cell_id);
438-
439-
const int z_lim = box.pbc_z ? 2 : 0;
440-
const int y_lim = box.pbc_y ? 2 : 0;
441-
const int x_lim = box.pbc_x ? 2 : 0;
442-
443-
for (int zz = -z_lim; zz <= z_lim; ++zz) {
444-
for (int yy = -y_lim; yy <= y_lim; ++yy) {
445-
for (int xx = -x_lim; xx <= x_lim; ++xx) {
446-
int neighbor_cell = cell_id + zz * nx * ny + yy * nx + xx;
447-
if (cell_id_x + xx < 0)
448-
neighbor_cell += nx;
449-
else if (cell_id_x + xx >= nx)
450-
neighbor_cell -= nx;
451-
if (cell_id_y + yy < 0)
452-
neighbor_cell += ny * nx;
453-
else if (cell_id_y + yy >= ny)
454-
neighbor_cell -= ny * nx;
455-
if (cell_id_z + zz < 0)
456-
neighbor_cell += nz * ny * nx;
457-
else if (cell_id_z + zz >= nz)
458-
neighbor_cell -= nz * ny * nx;
459-
460-
const int num_atoms_neighbor_cell = g_cell_count[neighbor_cell];
461-
const int num_atoms_previous_cells = g_cell_count_sum[neighbor_cell];
462-
463-
for (int m = 0; m < num_atoms_neighbor_cell; ++m) {
464-
const int n2 = g_cell_contents[num_atoms_previous_cells + m];
465-
466-
if (n2 < N1 || n2 >= N2 || n1 == n2) {
467-
continue;
468-
}
469-
470-
float x12 = g_x[n2] - x1;
471-
float y12 = g_y[n2] - y1;
472-
float z12 = g_z[n2] - z1;
473-
apply_mic(box, x12, y12, z12);
474-
float d12_square = x12 * x12 + y12 * y12 + z12 * z12;
475-
476-
float rc_radial = paramb.rc_radial;
477-
float rc_angular = paramb.rc_angular;
478-
479-
if (d12_square >= rc_radial * rc_radial) {
480-
continue;
481-
}
482-
483-
g_NL_radial[count_radial++ * N + n1] = n2;
484-
485-
if (d12_square < rc_angular * rc_angular) {
486-
g_NL_angular[count_angular++ * N + n1] = n2;
487-
}
488-
}
489-
}
415+
for (int i1 = 0; i1 < g_NN_global[n1]; ++i1) {
416+
int n2 = g_NL_global[n1 + N * i1];
417+
float x12 = g_x[n2] - x1;
418+
float y12 = g_y[n2] - y1;
419+
float z12 = g_z[n2] - z1;
420+
apply_mic(box, x12, y12, z12);
421+
float d12_square = x12 * x12 + y12 * y12 + z12 * z12;
422+
float rc_radial = paramb.rc_radial;
423+
float rc_angular = paramb.rc_angular;
424+
if (d12_square >= rc_radial * rc_radial) {
425+
continue;
426+
}
427+
g_NL_radial[count_radial++ * N + n1] = n2;
428+
if (d12_square < rc_angular * rc_angular) {
429+
g_NL_angular[count_angular++ * N + n1] = n2;
490430
}
491431
}
492432

@@ -1419,36 +1359,24 @@ void NEP_Charge::compute_large_box(
14191359
const int N = type.size();
14201360
const int grid_size = (N2 - N1 - 1) / BLOCK_SIZE + 1;
14211361

1422-
const double rc_cell_list = 0.5 * rc;
1423-
1424-
int num_bins[3];
1425-
box.get_num_bins(rc_cell_list, num_bins);
1426-
1427-
find_cell_list(
1428-
rc_cell_list,
1429-
num_bins,
1430-
box,
1431-
position_per_atom,
1432-
nep_data.cell_count,
1433-
nep_data.cell_count_sum,
1434-
nep_data.cell_contents);
1362+
neighbor.find_neighbor_global(
1363+
rc,
1364+
box,
1365+
type,
1366+
position_per_atom);
14351367

14361368
find_neighbor_list_large_box<<<grid_size, BLOCK_SIZE>>>(
14371369
paramb,
14381370
N,
14391371
N1,
14401372
N2,
1441-
num_bins[0],
1442-
num_bins[1],
1443-
num_bins[2],
14441373
box,
14451374
type.data(),
1446-
nep_data.cell_count.data(),
1447-
nep_data.cell_count_sum.data(),
1448-
nep_data.cell_contents.data(),
14491375
position_per_atom.data(),
14501376
position_per_atom.data() + N,
14511377
position_per_atom.data() + N * 2,
1378+
neighbor.NN.data(),
1379+
neighbor.NL.data(),
14521380
nep_data.NN_radial.data(),
14531381
nep_data.NL_radial.data(),
14541382
nep_data.NN_angular.data(),
@@ -1477,14 +1405,6 @@ void NEP_Charge::compute_large_box(
14771405
output_file.close();
14781406
}
14791407

1480-
gpu_sort_neighbor_list<<<N, paramb.MN_radial, paramb.MN_radial * sizeof(int)>>>(
1481-
N, nep_data.NN_radial.data(), nep_data.NL_radial.data());
1482-
GPU_CHECK_KERNEL
1483-
1484-
gpu_sort_neighbor_list<<<N, paramb.MN_angular, paramb.MN_angular * sizeof(int)>>>(
1485-
N, nep_data.NN_angular.data(), nep_data.NL_angular.data());
1486-
GPU_CHECK_KERNEL
1487-
14881408
find_descriptor<<<grid_size, BLOCK_SIZE>>>(
14891409
paramb,
14901410
annmb,

src/force/nep_charge.cuh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#pragma once
1717
#include "dftd3.cuh"
18+
#include "neighbor.cuh"
1819
#include "potential.cuh"
1920
#include "utilities/common.cuh"
2021
#include "utilities/gpu_vector.cuh"
@@ -32,9 +33,6 @@ struct NEP_Charge_Data {
3233
GPU_Vector<int> NN_angular; // angular neighbor list
3334
GPU_Vector<int> NL_angular; // angular neighbor list
3435
GPU_Vector<float> parameters; // parameters to be optimized
35-
GPU_Vector<int> cell_count;
36-
GPU_Vector<int> cell_count_sum;
37-
GPU_Vector<int> cell_contents;
3836
std::vector<int> cpu_NN_radial;
3937
std::vector<int> cpu_NN_angular;
4038
GPU_Vector<float> kx;
@@ -153,6 +151,7 @@ private:
153151
Charge_Para charge_para;
154152
Ewald ewald;
155153
PPPM pppm;
154+
Neighbor neighbor;
156155

157156
void update_potential(float* parameters, ANN& ann);
158157

0 commit comments

Comments
 (0)