@@ -1075,7 +1075,8 @@ static __global__ void find_k_and_G(
10751075static __global__ void zero_total_charge (
10761076 const int * Na,
10771077 const int * Na_sum,
1078- float * g_charge)
1078+ float * g_charge,
1079+ float * g_charge_shifted)
10791080{
10801081 int tid = threadIdx .x ;
10811082 int N1 = Na_sum[blockIdx .x ];
@@ -1102,7 +1103,7 @@ static __global__ void zero_total_charge(
11021103 for (int batch = 0 ; batch < number_of_batches; ++batch) {
11031104 int n = tid + batch * 1024 + N1;
11041105 if (n < N2) {
1105- g_charge[n] -= s_charge[0 ] / (N2 - N1);
1106+ g_charge_shifted[n] = g_charge[n] - s_charge[0 ] / (N2 - N1);
11061107 }
11071108 }
11081109}
@@ -1248,7 +1249,8 @@ void NEP_Charge::find_force(
12481249 zero_total_charge<<<dataset[device_id].Nc, 1024 >>> (
12491250 dataset[device_id].Na .data (),
12501251 dataset[device_id].Na_sum .data (),
1251- dataset[device_id].charge .data ());
1252+ dataset[device_id].charge .data (),
1253+ dataset[device_id].charge_shifted .data ());
12521254 GPU_CHECK_KERNEL
12531255
12541256 if (paramb.charge_mode != 3 ) {
@@ -1269,7 +1271,7 @@ void NEP_Charge::find_force(
12691271 charge_para.num_kpoints_max ,
12701272 dataset[device_id].Na .data (),
12711273 dataset[device_id].Na_sum .data (),
1272- dataset[device_id].charge .data (),
1274+ dataset[device_id].charge_shifted .data (),
12731275 dataset[device_id].r .data (),
12741276 dataset[device_id].r .data () + dataset[device_id].N ,
12751277 dataset[device_id].r .data () + dataset[device_id].N * 2 ,
@@ -1287,7 +1289,7 @@ void NEP_Charge::find_force(
12871289 charge_para.alpha_factor ,
12881290 dataset[device_id].Na .data (),
12891291 dataset[device_id].Na_sum .data (),
1290- dataset[device_id].charge .data (),
1292+ dataset[device_id].charge_shifted .data (),
12911293 dataset[device_id].r .data (),
12921294 dataset[device_id].r .data () + dataset[device_id].N ,
12931295 dataset[device_id].r .data () + dataset[device_id].N * 2 ,
@@ -1316,7 +1318,7 @@ void NEP_Charge::find_force(
13161318 charge_para.two_alpha_over_sqrt_pi ,
13171319 nep_data[device_id].NN_radial .data (),
13181320 nep_data[device_id].NL_radial .data (),
1319- dataset[device_id].charge .data (),
1321+ dataset[device_id].charge_shifted .data (),
13201322 nep_data[device_id].x12_radial .data (),
13211323 nep_data[device_id].y12_radial .data (),
13221324 nep_data[device_id].z12_radial .data (),
@@ -1336,7 +1338,7 @@ void NEP_Charge::find_force(
13361338 charge_para.B ,
13371339 nep_data[device_id].NN_radial .data (),
13381340 nep_data[device_id].NL_radial .data (),
1339- dataset[device_id].charge .data (),
1341+ dataset[device_id].charge_shifted .data (),
13401342 nep_data[device_id].x12_radial .data (),
13411343 nep_data[device_id].y12_radial .data (),
13421344 nep_data[device_id].z12_radial .data (),
0 commit comments