Skip to content

Commit 9a27378

Browse files
committed
enable charge loss term
1 parent dceb6e5 commit 9a27378

File tree

4 files changed

+20
-7
lines changed

4 files changed

+20
-7
lines changed

src/main_nep/dataset.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ void Dataset::initialize_gpu_data(Parameters& para)
159159
std::vector<int> type_cpu(N);
160160

161161
charge.resize(N);
162+
charge_shifted.resize(N);
162163
energy.resize(N);
163164
virial.resize(N * 6);
164165
force.resize(N * 3);

src/main_nep/dataset.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public:
4040
GPU_Vector<int> num_cell; // number of cells in the expanded box (3 components)
4141

4242
GPU_Vector<float> charge; // calculated charge in GPU
43+
GPU_Vector<float> charge_shifted; // shifted charge in GPU
4344
GPU_Vector<float> energy; // calculated energy in GPU
4445
GPU_Vector<float> virial; // calculated virial in GPU
4546
GPU_Vector<float> force; // calculated force in GPU

src/main_nep/fitness.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ void Fitness::compute(
215215
para, energy_shift_per_structure_not_used, true, true, m);
216216
auto rmse_force_array = train_set[batch_id][m].get_rmse_force(para, true, m);
217217
auto rmse_virial_array = train_set[batch_id][m].get_rmse_virial(para, true, m);
218+
auto rmse_charge_array = train_set[batch_id][m].get_rmse_charge(para, m);
218219
for (int t = 0; t <= para.num_types; ++t) {
219220
// energy
220221
float old_value = fitness[deviceCount * n + m + (7 * t + 3) * para.population_size];
@@ -234,6 +235,14 @@ void Fitness::compute(
234235
new_value = old_value * old_value * count_batch + new_value * new_value;
235236
new_value = sqrt(new_value / (count_batch + 1));
236237
fitness[deviceCount * n + m + (7 * t + 5) * para.population_size] = new_value;
238+
// charge
239+
if (para.charge_mode) {
240+
old_value = fitness[deviceCount * n + m + (7 * t + 6) * para.population_size];
241+
new_value = para.lambda_q * rmse_charge_array[t];
242+
new_value = old_value * old_value * count_batch + new_value * new_value;
243+
new_value = sqrt(new_value / (count_batch + 1));
244+
fitness[deviceCount * n + m + (7 * t + 6) * para.population_size] = new_value;
245+
}
237246
}
238247
}
239248
}

src/main_nep/nep_charge.cu

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,8 @@ static __global__ void find_k_and_G(
10751075
static __global__ void zero_total_charge(
10761076
const int* Na,
10771077
const int* Na_sum,
1078-
float* g_charge)
1078+
float* g_charge,
1079+
float* g_charge_shifted)
10791080
{
10801081
int tid = threadIdx.x;
10811082
int N1 = Na_sum[blockIdx.x];
@@ -1102,7 +1103,7 @@ static __global__ void zero_total_charge(
11021103
for (int batch = 0; batch < number_of_batches; ++batch) {
11031104
int n = tid + batch * 1024 + N1;
11041105
if (n < N2) {
1105-
g_charge[n] -= s_charge[0] / (N2 - N1);
1106+
g_charge_shifted[n] = g_charge[n] - s_charge[0] / (N2 - N1);
11061107
}
11071108
}
11081109
}
@@ -1248,7 +1249,8 @@ void NEP_Charge::find_force(
12481249
zero_total_charge<<<dataset[device_id].Nc, 1024>>>(
12491250
dataset[device_id].Na.data(),
12501251
dataset[device_id].Na_sum.data(),
1251-
dataset[device_id].charge.data());
1252+
dataset[device_id].charge.data(),
1253+
dataset[device_id].charge_shifted.data());
12521254
GPU_CHECK_KERNEL
12531255

12541256
if (paramb.charge_mode != 3) {
@@ -1269,7 +1271,7 @@ void NEP_Charge::find_force(
12691271
charge_para.num_kpoints_max,
12701272
dataset[device_id].Na.data(),
12711273
dataset[device_id].Na_sum.data(),
1272-
dataset[device_id].charge.data(),
1274+
dataset[device_id].charge_shifted.data(),
12731275
dataset[device_id].r.data(),
12741276
dataset[device_id].r.data() + dataset[device_id].N,
12751277
dataset[device_id].r.data() + dataset[device_id].N * 2,
@@ -1287,7 +1289,7 @@ void NEP_Charge::find_force(
12871289
charge_para.alpha_factor,
12881290
dataset[device_id].Na.data(),
12891291
dataset[device_id].Na_sum.data(),
1290-
dataset[device_id].charge.data(),
1292+
dataset[device_id].charge_shifted.data(),
12911293
dataset[device_id].r.data(),
12921294
dataset[device_id].r.data() + dataset[device_id].N,
12931295
dataset[device_id].r.data() + dataset[device_id].N * 2,
@@ -1316,7 +1318,7 @@ void NEP_Charge::find_force(
13161318
charge_para.two_alpha_over_sqrt_pi,
13171319
nep_data[device_id].NN_radial.data(),
13181320
nep_data[device_id].NL_radial.data(),
1319-
dataset[device_id].charge.data(),
1321+
dataset[device_id].charge_shifted.data(),
13201322
nep_data[device_id].x12_radial.data(),
13211323
nep_data[device_id].y12_radial.data(),
13221324
nep_data[device_id].z12_radial.data(),
@@ -1336,7 +1338,7 @@ void NEP_Charge::find_force(
13361338
charge_para.B,
13371339
nep_data[device_id].NN_radial.data(),
13381340
nep_data[device_id].NL_radial.data(),
1339-
dataset[device_id].charge.data(),
1341+
dataset[device_id].charge_shifted.data(),
13401342
nep_data[device_id].x12_radial.data(),
13411343
nep_data[device_id].y12_radial.data(),
13421344
nep_data[device_id].z12_radial.data(),

0 commit comments

Comments
 (0)