@@ -352,11 +352,9 @@ NEP_Charge::NEP_Charge(const char* file_potential, const int num_atoms)
352352 nep_data.Fp .resize (num_atoms * annmb.dim );
353353 nep_data.sum_fxyz .resize (
354354 num_atoms * (paramb.n_max_angular + 1 ) * ((paramb.L_max + 1 ) * (paramb.L_max + 1 ) - 1 ));
355- nep_data.cell_count .resize (num_atoms);
356- nep_data.cell_count_sum .resize (num_atoms);
357- nep_data.cell_contents .resize (num_atoms);
358355 nep_data.cpu_NN_radial .resize (num_atoms);
359356 nep_data.cpu_NN_angular .resize (num_atoms);
357+ neighbor.initialize (rc, num_atoms, paramb.MN_radial );
360358
361359 initialize_dftd3 ();
362360}
@@ -391,17 +389,13 @@ static __global__ void find_neighbor_list_large_box(
391389 const int N,
392390 const int N1,
393391 const int N2,
394- const int nx,
395- const int ny,
396- const int nz,
397392 const Box box,
398393 const int * g_type,
399- const int * __restrict__ g_cell_count,
400- const int * __restrict__ g_cell_count_sum,
401- const int * __restrict__ g_cell_contents,
402394 const double * __restrict__ g_x,
403395 const double * __restrict__ g_y,
404396 const double * __restrict__ g_z,
397+ const int * __restrict__ g_NN_global,
398+ const int * __restrict__ g_NL_global,
405399 int * g_NN_radial,
406400 int * g_NL_radial,
407401 int * g_NN_angular,
@@ -418,75 +412,21 @@ static __global__ void find_neighbor_list_large_box(
418412 int count_radial = 0 ;
419413 int count_angular = 0 ;
420414
421- int cell_id;
422- int cell_id_x;
423- int cell_id_y;
424- int cell_id_z;
425- find_cell_id (
426- box,
427- x1,
428- y1,
429- z1,
430- 2 .0f * paramb.rcinv_radial ,
431- nx,
432- ny,
433- nz,
434- cell_id_x,
435- cell_id_y,
436- cell_id_z,
437- cell_id);
438-
439- const int z_lim = box.pbc_z ? 2 : 0 ;
440- const int y_lim = box.pbc_y ? 2 : 0 ;
441- const int x_lim = box.pbc_x ? 2 : 0 ;
442-
443- for (int zz = -z_lim; zz <= z_lim; ++zz) {
444- for (int yy = -y_lim; yy <= y_lim; ++yy) {
445- for (int xx = -x_lim; xx <= x_lim; ++xx) {
446- int neighbor_cell = cell_id + zz * nx * ny + yy * nx + xx;
447- if (cell_id_x + xx < 0 )
448- neighbor_cell += nx;
449- else if (cell_id_x + xx >= nx)
450- neighbor_cell -= nx;
451- if (cell_id_y + yy < 0 )
452- neighbor_cell += ny * nx;
453- else if (cell_id_y + yy >= ny)
454- neighbor_cell -= ny * nx;
455- if (cell_id_z + zz < 0 )
456- neighbor_cell += nz * ny * nx;
457- else if (cell_id_z + zz >= nz)
458- neighbor_cell -= nz * ny * nx;
459-
460- const int num_atoms_neighbor_cell = g_cell_count[neighbor_cell];
461- const int num_atoms_previous_cells = g_cell_count_sum[neighbor_cell];
462-
463- for (int m = 0 ; m < num_atoms_neighbor_cell; ++m) {
464- const int n2 = g_cell_contents[num_atoms_previous_cells + m];
465-
466- if (n2 < N1 || n2 >= N2 || n1 == n2) {
467- continue ;
468- }
469-
470- float x12 = g_x[n2] - x1;
471- float y12 = g_y[n2] - y1;
472- float z12 = g_z[n2] - z1;
473- apply_mic (box, x12, y12, z12);
474- float d12_square = x12 * x12 + y12 * y12 + z12 * z12;
475-
476- float rc_radial = paramb.rc_radial ;
477- float rc_angular = paramb.rc_angular ;
478-
479- if (d12_square >= rc_radial * rc_radial) {
480- continue ;
481- }
482-
483- g_NL_radial[count_radial++ * N + n1] = n2;
484-
485- if (d12_square < rc_angular * rc_angular) {
486- g_NL_angular[count_angular++ * N + n1] = n2;
487- }
488- }
489- }
415+ for (int i1 = 0 ; i1 < g_NN_global[n1]; ++i1) {
416+ int n2 = g_NL_global[n1 + N * i1];
417+ float x12 = g_x[n2] - x1;
418+ float y12 = g_y[n2] - y1;
419+ float z12 = g_z[n2] - z1;
420+ apply_mic (box, x12, y12, z12);
421+ float d12_square = x12 * x12 + y12 * y12 + z12 * z12;
422+ float rc_radial = paramb.rc_radial ;
423+ float rc_angular = paramb.rc_angular ;
424+ if (d12_square >= rc_radial * rc_radial) {
425+ continue ;
426+ }
427+ g_NL_radial[count_radial++ * N + n1] = n2;
428+ if (d12_square < rc_angular * rc_angular) {
429+ g_NL_angular[count_angular++ * N + n1] = n2;
490430 }
491431 }
492432
@@ -1419,36 +1359,24 @@ void NEP_Charge::compute_large_box(
14191359 const int N = type.size ();
14201360 const int grid_size = (N2 - N1 - 1 ) / BLOCK_SIZE + 1 ;
14211361
1422- const double rc_cell_list = 0.5 * rc;
1423-
1424- int num_bins[3 ];
1425- box.get_num_bins (rc_cell_list, num_bins);
1426-
1427- find_cell_list (
1428- rc_cell_list,
1429- num_bins,
1430- box,
1431- position_per_atom,
1432- nep_data.cell_count ,
1433- nep_data.cell_count_sum ,
1434- nep_data.cell_contents );
1362+ neighbor.find_neighbor_global (
1363+ rc,
1364+ box,
1365+ type,
1366+ position_per_atom);
14351367
14361368 find_neighbor_list_large_box<<<grid_size, BLOCK_SIZE>>> (
14371369 paramb,
14381370 N,
14391371 N1,
14401372 N2,
1441- num_bins[0 ],
1442- num_bins[1 ],
1443- num_bins[2 ],
14441373 box,
14451374 type.data (),
1446- nep_data.cell_count .data (),
1447- nep_data.cell_count_sum .data (),
1448- nep_data.cell_contents .data (),
14491375 position_per_atom.data (),
14501376 position_per_atom.data () + N,
14511377 position_per_atom.data () + N * 2 ,
1378+ neighbor.NN .data (),
1379+ neighbor.NL .data (),
14521380 nep_data.NN_radial .data (),
14531381 nep_data.NL_radial .data (),
14541382 nep_data.NN_angular .data (),
@@ -1477,14 +1405,6 @@ void NEP_Charge::compute_large_box(
14771405 output_file.close ();
14781406 }
14791407
1480- gpu_sort_neighbor_list<<<N, paramb.MN_radial, paramb.MN_radial * sizeof (int )>>> (
1481- N, nep_data.NN_radial .data (), nep_data.NL_radial .data ());
1482- GPU_CHECK_KERNEL
1483-
1484- gpu_sort_neighbor_list<<<N, paramb.MN_angular, paramb.MN_angular * sizeof (int )>>> (
1485- N, nep_data.NN_angular .data (), nep_data.NL_angular .data ());
1486- GPU_CHECK_KERNEL
1487-
14881408 find_descriptor<<<grid_size, BLOCK_SIZE>>> (
14891409 paramb,
14901410 annmb,
0 commit comments