Skip to content

Commit 878140b

Browse files
authored
Merge pull request brucefan1983#1170 from brucefan1983/charge_mode5
charge mode 5
2 parents 4ebefc7 + b464c13 commit 878140b

File tree

2 files changed

+75
-71
lines changed

2 files changed

+75
-71
lines changed

src/main_nep/nep_charge.cu

Lines changed: 69 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,12 @@ static __global__ void find_descriptors_radial(
146146
float d12 = sqrt(x12 * x12 + y12 * y12 + z12 * z12);
147147
float fc12;
148148
int t2 = g_type[n2];
149-
float rc = (paramb.charge_mode == 4) ? paramb.rc_angular : paramb.rc_radial;
149+
float rc = (paramb.charge_mode >= 4) ? paramb.rc_angular : paramb.rc_radial;
150150
if (paramb.use_typewise_cutoff) {
151151
rc = min(
152152
(COVALENT_RADIUS[paramb.atomic_numbers[t1]] +
153153
COVALENT_RADIUS[paramb.atomic_numbers[t2]]) *
154-
((paramb.charge_mode == 4) ? paramb.typewise_cutoff_angular_factor : paramb.typewise_cutoff_radial_factor),
154+
((paramb.charge_mode >= 4) ? paramb.typewise_cutoff_angular_factor : paramb.typewise_cutoff_radial_factor),
155155
rc);
156156
}
157157
float rcinv = 1.0f / rc;
@@ -332,7 +332,7 @@ NEP_Charge::NEP_Charge(
332332
nep_data[device_id].S_imag.resize(Nc * charge_para.num_kpoints_max);
333333
nep_data[device_id].D_real.resize(N);
334334
nep_data[device_id].num_kpoints.resize(Nc);
335-
if (paramb.charge_mode == 4) {
335+
if (paramb.charge_mode >= 4) {
336336
nep_data[device_id].C6.resize(N);
337337
nep_data[device_id].C6_derivative.resize(N * annmb[device_id].dim);
338338
nep_data[device_id].D_C6.resize(N);
@@ -342,7 +342,7 @@ NEP_Charge::NEP_Charge(
342342

343343
void NEP_Charge::update_potential(Parameters& para, float* parameters, ANN& ann)
344344
{
345-
const int num_outputs = (para.charge_mode == 4) ? 3 : 2;
345+
const int num_outputs = (para.charge_mode >= 4) ? 3 : 2;
346346
float* pointer = parameters;
347347
for (int t = 0; t < paramb.num_types; ++t) {
348348
ann.w0[t] = pointer;
@@ -580,12 +580,12 @@ static __global__ void find_force_radial(
580580
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
581581
float d12inv = 1.0f / d12;
582582
float fc12, fcp12;
583-
float rc = (paramb.charge_mode == 4) ? paramb.rc_angular : paramb.rc_radial;
583+
float rc = (paramb.charge_mode >= 4) ? paramb.rc_angular : paramb.rc_radial;
584584
if (paramb.use_typewise_cutoff) {
585585
rc = min(
586586
(COVALENT_RADIUS[paramb.atomic_numbers[t1]] +
587587
COVALENT_RADIUS[paramb.atomic_numbers[t2]]) *
588-
((paramb.charge_mode == 4) ? paramb.typewise_cutoff_angular_factor : paramb.typewise_cutoff_radial_factor),
588+
((paramb.charge_mode >= 4) ? paramb.typewise_cutoff_angular_factor : paramb.typewise_cutoff_radial_factor),
589589
rc);
590590
}
591591
float rcinv = 1.0f / rc;
@@ -603,7 +603,7 @@ static __global__ void find_force_radial(
603603
gnp12 += fnp12[k] * annmb.c[c_index];
604604
}
605605
float tmp12 = g_Fp[n1 + n * N] + g_charge_derivative[n1 + n * N] * g_D_real[n1];
606-
if (paramb.charge_mode == 4) {
606+
if (paramb.charge_mode >= 4) {
607607
tmp12 += g_C6_derivative[n1 + n * N] * g_D_C6[n1];
608608
}
609609
tmp12 *= gnp12 * d12inv;
@@ -671,7 +671,7 @@ static __global__ void find_force_angular(
671671
for (int d = 0; d < paramb.dim_angular; ++d) {
672672
float tmp = g_Fp[(paramb.n_max_radial + 1 + d) * N + n1]
673673
+ g_charge_derivative[(paramb.n_max_radial + 1 + d) * N + n1] * g_D_real[n1];
674-
if (paramb.charge_mode == 4) {
674+
if (paramb.charge_mode >= 4) {
675675
tmp += g_C6_derivative[(paramb.n_max_radial + 1 + d) * N + n1] * g_D_C6[n1];
676676
}
677677
Fp[d] = tmp;
@@ -763,12 +763,12 @@ static __global__ void find_bec_radial(
763763
float d12 = sqrt(r12[0] * r12[0] + r12[1] * r12[1] + r12[2] * r12[2]);
764764
float d12inv = 1.0f / d12;
765765
float fc12, fcp12;
766-
float rc = (paramb.charge_mode == 4) ? paramb.rc_angular : paramb.rc_radial;
766+
float rc = (paramb.charge_mode >= 4) ? paramb.rc_angular : paramb.rc_radial;
767767
if (paramb.use_typewise_cutoff) {
768768
rc = min(
769769
(COVALENT_RADIUS[paramb.atomic_numbers[t1]] +
770770
COVALENT_RADIUS[paramb.atomic_numbers[t2]]) *
771-
((paramb.charge_mode == 4) ? paramb.typewise_cutoff_angular_factor : paramb.typewise_cutoff_radial_factor),
771+
((paramb.charge_mode >= 4) ? paramb.typewise_cutoff_angular_factor : paramb.typewise_cutoff_radial_factor),
772772
rc);
773773
}
774774
float rcinv = 1.0f / rc;
@@ -1508,14 +1508,14 @@ void NEP_Charge::find_force(
15081508

15091509
find_descriptors_radial<<<grid_size, block_size>>>(
15101510
dataset[device_id].N,
1511-
(paramb.charge_mode == 4) ? nep_data[device_id].NN_angular.data() : nep_data[device_id].NN_radial.data(),
1512-
(paramb.charge_mode == 4) ? nep_data[device_id].NL_angular.data() : nep_data[device_id].NL_radial.data(),
1511+
(paramb.charge_mode >= 4) ? nep_data[device_id].NN_angular.data() : nep_data[device_id].NN_radial.data(),
1512+
(paramb.charge_mode >= 4) ? nep_data[device_id].NL_angular.data() : nep_data[device_id].NL_radial.data(),
15131513
paramb,
15141514
annmb[device_id],
15151515
dataset[device_id].type.data(),
1516-
(paramb.charge_mode == 4) ? nep_data[device_id].x12_angular.data() : nep_data[device_id].x12_radial.data(),
1517-
(paramb.charge_mode == 4) ? nep_data[device_id].y12_angular.data() : nep_data[device_id].y12_radial.data(),
1518-
(paramb.charge_mode == 4) ? nep_data[device_id].z12_angular.data() : nep_data[device_id].z12_radial.data(),
1516+
(paramb.charge_mode >= 4) ? nep_data[device_id].x12_angular.data() : nep_data[device_id].x12_radial.data(),
1517+
(paramb.charge_mode >= 4) ? nep_data[device_id].y12_angular.data() : nep_data[device_id].y12_radial.data(),
1518+
(paramb.charge_mode >= 4) ? nep_data[device_id].z12_angular.data() : nep_data[device_id].z12_radial.data(),
15191519
nep_data[device_id].descriptors.data());
15201520
GPU_CHECK_KERNEL
15211521

@@ -1580,7 +1580,7 @@ void NEP_Charge::find_force(
15801580
dataset[device_id].virial.data());
15811581
GPU_CHECK_KERNEL
15821582

1583-
if (paramb.charge_mode == 4) {
1583+
if (paramb.charge_mode >= 4) {
15841584
apply_ann_vdw<<<grid_size, block_size>>>(
15851585
dataset[device_id].N,
15861586
paramb,
@@ -1618,53 +1618,55 @@ void NEP_Charge::find_force(
16181618
dataset[device_id].charge_shifted.data());
16191619
GPU_CHECK_KERNEL
16201620

1621-
// get BEC (the diagonal part)
1622-
find_bec_diagonal<<<grid_size, block_size>>>(
1623-
dataset[device_id].N,
1624-
dataset[device_id].charge_shifted.data(),
1625-
dataset[device_id].bec.data());
1626-
GPU_CHECK_KERNEL
1621+
if (para.has_bec) {
1622+
// get BEC (the diagonal part)
1623+
find_bec_diagonal<<<grid_size, block_size>>>(
1624+
dataset[device_id].N,
1625+
dataset[device_id].charge_shifted.data(),
1626+
dataset[device_id].bec.data());
1627+
GPU_CHECK_KERNEL
16271628

1628-
// get BEC (radial descriptor part)
1629-
find_bec_radial<<<grid_size, block_size>>>(
1630-
dataset[device_id].N,
1631-
(paramb.charge_mode == 4) ? nep_data[device_id].NN_angular.data() : nep_data[device_id].NN_radial.data(),
1632-
(paramb.charge_mode == 4) ? nep_data[device_id].NL_angular.data() : nep_data[device_id].NL_radial.data(),
1633-
paramb,
1634-
annmb[device_id],
1635-
dataset[device_id].type.data(),
1636-
(paramb.charge_mode == 4) ? nep_data[device_id].x12_angular.data() : nep_data[device_id].x12_radial.data(),
1637-
(paramb.charge_mode == 4) ? nep_data[device_id].y12_angular.data() : nep_data[device_id].y12_radial.data(),
1638-
(paramb.charge_mode == 4) ? nep_data[device_id].z12_angular.data() : nep_data[device_id].z12_radial.data(),
1639-
nep_data[device_id].charge_derivative.data(),
1640-
dataset[device_id].bec.data());
1641-
GPU_CHECK_KERNEL
1629+
// get BEC (radial descriptor part)
1630+
find_bec_radial<<<grid_size, block_size>>>(
1631+
dataset[device_id].N,
1632+
(paramb.charge_mode >= 4) ? nep_data[device_id].NN_angular.data() : nep_data[device_id].NN_radial.data(),
1633+
(paramb.charge_mode >= 4) ? nep_data[device_id].NL_angular.data() : nep_data[device_id].NL_radial.data(),
1634+
paramb,
1635+
annmb[device_id],
1636+
dataset[device_id].type.data(),
1637+
(paramb.charge_mode >= 4) ? nep_data[device_id].x12_angular.data() : nep_data[device_id].x12_radial.data(),
1638+
(paramb.charge_mode >= 4) ? nep_data[device_id].y12_angular.data() : nep_data[device_id].y12_radial.data(),
1639+
(paramb.charge_mode >= 4) ? nep_data[device_id].z12_angular.data() : nep_data[device_id].z12_radial.data(),
1640+
nep_data[device_id].charge_derivative.data(),
1641+
dataset[device_id].bec.data());
1642+
GPU_CHECK_KERNEL
16421643

1643-
// get BEC (angular descriptor part)
1644-
find_bec_angular<<<grid_size, block_size>>>(
1645-
dataset[device_id].N,
1646-
nep_data[device_id].NN_angular.data(),
1647-
nep_data[device_id].NL_angular.data(),
1648-
paramb,
1649-
annmb[device_id],
1650-
dataset[device_id].type.data(),
1651-
nep_data[device_id].x12_angular.data(),
1652-
nep_data[device_id].y12_angular.data(),
1653-
nep_data[device_id].z12_angular.data(),
1654-
nep_data[device_id].charge_derivative.data(),
1655-
nep_data[device_id].sum_fxyz.data(),
1656-
dataset[device_id].bec.data());
1657-
GPU_CHECK_KERNEL
1644+
// get BEC (angular descriptor part)
1645+
find_bec_angular<<<grid_size, block_size>>>(
1646+
dataset[device_id].N,
1647+
nep_data[device_id].NN_angular.data(),
1648+
nep_data[device_id].NL_angular.data(),
1649+
paramb,
1650+
annmb[device_id],
1651+
dataset[device_id].type.data(),
1652+
nep_data[device_id].x12_angular.data(),
1653+
nep_data[device_id].y12_angular.data(),
1654+
nep_data[device_id].z12_angular.data(),
1655+
nep_data[device_id].charge_derivative.data(),
1656+
nep_data[device_id].sum_fxyz.data(),
1657+
dataset[device_id].bec.data());
1658+
GPU_CHECK_KERNEL
16581659

1659-
// scale q to q * sqrt(epsilon_inf)
1660-
scale_bec<<<grid_size, block_size>>>(
1661-
dataset[device_id].N,
1662-
annmb[device_id].sqrt_epsilon_inf,
1663-
dataset[device_id].bec.data());
1664-
GPU_CHECK_KERNEL
1660+
// scale q to q * sqrt(epsilon_inf)
1661+
scale_bec<<<grid_size, block_size>>>(
1662+
dataset[device_id].N,
1663+
annmb[device_id].sqrt_epsilon_inf,
1664+
dataset[device_id].bec.data());
1665+
GPU_CHECK_KERNEL
1666+
}
16651667

16661668
// reciprocal space
1667-
if (paramb.charge_mode != 3) {
1669+
if (paramb.charge_mode == 1 || paramb.charge_mode == 2 || paramb.charge_mode == 4) {
16681670
find_k_and_G<<<(dataset[device_id].Nc - 1) / 64 + 1, 64>>>(
16691671
dataset[device_id].Nc,
16701672
charge_para.num_kpoints_max,
@@ -1741,8 +1743,8 @@ void NEP_Charge::find_force(
17411743
GPU_CHECK_KERNEL
17421744
}
17431745

1744-
// mode 3 has real space only
1745-
if (paramb.charge_mode == 3) {
1746+
// modes 3 and 5 has real space only
1747+
if (paramb.charge_mode == 3 || paramb.charge_mode == 5) {
17461748
find_force_charge_real_space_only<<<grid_size, block_size>>>(
17471749
dataset[device_id].N,
17481750
charge_para.alpha,
@@ -1764,8 +1766,8 @@ void NEP_Charge::find_force(
17641766
GPU_CHECK_KERNEL
17651767
}
17661768

1767-
// mode 4 has vdw
1768-
if (paramb.charge_mode == 4) {
1769+
// modes 4 and 5 has vdw
1770+
if (paramb.charge_mode >= 4) {
17691771
find_force_vdw_static<<<grid_size, block_size>>>(
17701772
dataset[device_id].N,
17711773
nep_data[device_id].NN_radial.data(),
@@ -1785,14 +1787,14 @@ void NEP_Charge::find_force(
17851787

17861788
find_force_radial<<<grid_size, block_size>>>(
17871789
dataset[device_id].N,
1788-
(paramb.charge_mode == 4) ? nep_data[device_id].NN_angular.data() : nep_data[device_id].NN_radial.data(),
1789-
(paramb.charge_mode == 4) ? nep_data[device_id].NL_angular.data() : nep_data[device_id].NL_radial.data(),
1790+
(paramb.charge_mode >= 4) ? nep_data[device_id].NN_angular.data() : nep_data[device_id].NN_radial.data(),
1791+
(paramb.charge_mode >= 4) ? nep_data[device_id].NL_angular.data() : nep_data[device_id].NL_radial.data(),
17901792
paramb,
17911793
annmb[device_id],
17921794
dataset[device_id].type.data(),
1793-
(paramb.charge_mode == 4) ? nep_data[device_id].x12_angular.data() : nep_data[device_id].x12_radial.data(),
1794-
(paramb.charge_mode == 4) ? nep_data[device_id].y12_angular.data() : nep_data[device_id].y12_radial.data(),
1795-
(paramb.charge_mode == 4) ? nep_data[device_id].z12_angular.data() : nep_data[device_id].z12_radial.data(),
1795+
(paramb.charge_mode >= 4) ? nep_data[device_id].x12_angular.data() : nep_data[device_id].x12_radial.data(),
1796+
(paramb.charge_mode >= 4) ? nep_data[device_id].y12_angular.data() : nep_data[device_id].y12_radial.data(),
1797+
(paramb.charge_mode >= 4) ? nep_data[device_id].z12_angular.data() : nep_data[device_id].z12_radial.data(),
17961798
nep_data[device_id].Fp.data(),
17971799
nep_data[device_id].charge_derivative.data(),
17981800
nep_data[device_id].D_real.data(),

src/main_nep/parameters.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ void Parameters::calculate_parameters()
222222
if (charge_mode) {
223223
number_of_variables_ann_1 += num_neurons1;
224224
number_of_variables_ann += num_neurons1 * num_types + 1;
225-
if (charge_mode == 4) {
225+
if (charge_mode >= 4) {
226226
number_of_variables_ann_1 += num_neurons1;
227227
number_of_variables_ann += num_neurons1 * num_types;
228228
}
@@ -458,7 +458,9 @@ void Parameters::report_inputs()
458458
} else if (charge_mode == 3) {
459459
printf(" (input) use NEP-Charge and include real-space only; lambda_q = %g.\n", lambda_q);
460460
} else if (charge_mode == 4) {
461-
printf(" (input) use NEP-Charge-VdW; lambda_q = %g.\n", lambda_q);
461+
printf(" (input) use NEP-Charge-VdW and include k-space only; lambda_q = %g.\n", lambda_q);
462+
} else if (charge_mode == 5) {
463+
printf(" (input) use NEP-Charge-VdW and include real-space only; lambda_q = %g.\n", lambda_q);
462464
}
463465
}
464466

@@ -1303,8 +1305,8 @@ void Parameters::parse_charge_mode(const char** param, int num_param)
13031305
if (!is_valid_int(param[1], &charge_mode)) {
13041306
PRINT_INPUT_ERROR("charge mode should be an integer.\n");
13051307
}
1306-
if (charge_mode != 0 && charge_mode != 1 && charge_mode != 2 && charge_mode != 3 && charge_mode != 4) {
1307-
PRINT_INPUT_ERROR("charge mode should be 0 or 1 or 2 or 3 or 4.");
1308+
if (charge_mode < 0 || charge_mode > 5) {
1309+
PRINT_INPUT_ERROR("charge mode should be 0 or 1 or 2 or 3 or 4 or 5.");
13081310
}
13091311
}
13101312

0 commit comments

Comments
 (0)