Skip to content

Commit 9cd52a8

Browse files
authored
Merge pull request brucefan1983#1293 from brucefan1983/try_population_mix
lambda_q and lambda_ z
2 parents 62884e3 + 68835a0 commit 9cd52a8

11 files changed

Lines changed: 287 additions & 113 deletions

File tree

doc/nep/input_parameters/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ Below you can find a listing of keywords for the ``nep.in`` input file.
2828
lambda_e
2929
lambda_f
3030
lambda_v
31+
lambda_q
32+
lambda_z
3133
atomic_v
3234
lambda_shear
3335
force_delta
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
.. _kw_lambda_q:
2+
.. index::
3+
single: lambda_q (keyword in nep.in)
4+
5+
:attr:`lambda_q`
6+
================
7+
8+
This keyword sets the weight :math:`\lambda_q` of the loss term associated with the **total charge** (the charge for a whole structure) in the :ref:`loss function <nep_loss_function>`.
9+
The syntax is::
10+
11+
lambda_q <weight>
12+
13+
Here, :attr:`<weight>` represents :math:`\lambda_q`, which must satisfy :math:`\lambda_q \geq 0` and defaults to :math:`\lambda_q = 0.5`.
14+
15+
This keyword is only relevant for the qNEP models.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
.. _kw_lambda_z:
2+
.. index::
3+
single: lambda_z (keyword in nep.in)
4+
5+
:attr:`lambda_z`
6+
================
7+
8+
This keyword sets the weight :math:`\lambda_z` of the loss term associated with the **Born effective charge** in the :ref:`loss function <nep_loss_function>`.
9+
The syntax is::
10+
11+
lambda_z <weight>
12+
13+
Here, :attr:`<weight>` represents :math:`\lambda_z`, which must satisfy :math:`\lambda_z \geq 0` and defaults to :math:`\lambda_z = 0.5`.
14+
15+
This keyword is only relevant for the qNEP models.

src/main_nep/dataset.cu

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -976,32 +976,52 @@ std::vector<float> Dataset::get_rmse_charge(Parameters& para, int device_id)
976976
error_gpu.data());
977977
CHECK(gpuMemcpy(error_cpu.data(), error_gpu.data(), mem, gpuMemcpyDeviceToHost));
978978
for (int n = 0; n < Nc; ++n) {
979-
float rmse_temp = error_cpu[n];
980-
for (int t = 0; t < para.num_types + 1; ++t) {
981-
if (has_type[t * Nc + n]) {
982-
rmse_array[t] += rmse_temp;
983-
count_array[t] += 1;
984-
}
979+
float rmse_temp = error_cpu[n];
980+
for (int t = 0; t < para.num_types + 1; ++t) {
981+
if (has_type[t * Nc + n]) {
982+
rmse_array[t] += rmse_temp;
983+
count_array[t] += 1;
985984
}
985+
}
986986
}
987987

988-
if (para.has_bec) {
989-
gpu_sum_bec_error<<<Nc, block_size, sizeof(float) * block_size>>>(
990-
N,
991-
Na.data(),
992-
Na_sum.data(),
993-
bec.data(),
994-
bec_ref_gpu.data(),
995-
error_gpu.data());
996-
CHECK(gpuMemcpy(error_cpu.data(), error_gpu.data(), mem, gpuMemcpyDeviceToHost));
997-
for (int n = 0; n < Nc; ++n) {
998-
if (structures[n].has_bec) {
999-
float rmse_temp = error_cpu[n];
1000-
for (int t = 0; t < para.num_types + 1; ++t) {
1001-
if (has_type[t * Nc + n]) {
1002-
rmse_array[t] += rmse_temp / (Na_cpu[n]);
1003-
count_array[t] += 9;
1004-
}
988+
for (int t = 0; t <= para.num_types; ++t) {
989+
if (count_array[t] > 0) {
990+
rmse_array[t] = sqrt(rmse_array[t] / count_array[t]);
991+
}
992+
}
993+
return rmse_array;
994+
}
995+
996+
std::vector<float> Dataset::get_rmse_bec(Parameters& para, int device_id)
997+
{
998+
std::vector<float> rmse_array(para.num_types + 1, 0.0f);
999+
if (!(para.charge_mode && para.has_bec)) {
1000+
return rmse_array;
1001+
}
1002+
1003+
CHECK(gpuSetDevice(device_id));
1004+
1005+
std::vector<int> count_array(para.num_types + 1, 0);
1006+
1007+
int mem = sizeof(float) * Nc;
1008+
const int block_size = 256;
1009+
1010+
gpu_sum_bec_error<<<Nc, block_size, sizeof(float) * block_size>>>(
1011+
N,
1012+
Na.data(),
1013+
Na_sum.data(),
1014+
bec.data(),
1015+
bec_ref_gpu.data(),
1016+
error_gpu.data());
1017+
CHECK(gpuMemcpy(error_cpu.data(), error_gpu.data(), mem, gpuMemcpyDeviceToHost));
1018+
for (int n = 0; n < Nc; ++n) {
1019+
if (structures[n].has_bec) {
1020+
float rmse_temp = error_cpu[n];
1021+
for (int t = 0; t < para.num_types + 1; ++t) {
1022+
if (has_type[t * Nc + n]) {
1023+
rmse_array[t] += rmse_temp / (Na_cpu[n]);
1024+
count_array[t] += 9;
10051025
}
10061026
}
10071027
}
@@ -1014,4 +1034,3 @@ std::vector<float> Dataset::get_rmse_charge(Parameters& para, int device_id)
10141034
}
10151035
return rmse_array;
10161036
}
1017-

src/main_nep/dataset.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ public:
9292
std::vector<float> get_rmse_virial(Parameters& para, const bool use_weight, int device_id);
9393
std::vector<float> get_rmse_avirial(Parameters& para, const bool use_weight, int device_id);
9494
std::vector<float> get_rmse_charge(Parameters& para, int device_id);
95+
std::vector<float> get_rmse_bec(Parameters& para, int device_id);
9596

9697
private:
9798
void copy_structures(std::vector<Structure>& structures_input, int n1, int n2);

src/main_nep/fitness.cu

Lines changed: 108 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,14 @@ Fitness::~Fitness()
149149
}
150150

151151
void Fitness::compute(
152-
const int generation, Parameters& para, const float* population, float* fitness)
152+
const int generation,
153+
Parameters& para,
154+
const float* population,
155+
float* fitness_energy,
156+
float* fitness_force,
157+
float* fitness_virial,
158+
float* fitness_charge,
159+
float* fitness_bec)
153160
{
154161
int deviceCount;
155162
CHECK(gpuGetDeviceCount(&deviceCount));
@@ -180,20 +187,19 @@ void Fitness::compute(
180187
auto rmse_force_array = train_set[batch_id][m].get_rmse_force(para, true, m);
181188
auto rmse_virial_array = train_set[batch_id][m].get_rmse_virial(para, true, m);
182189
auto rmse_charge_array = train_set[batch_id][m].get_rmse_charge(para, m);
190+
auto rmse_bec_array = train_set[batch_id][m].get_rmse_bec(para, m);
183191

184192
for (int t = 0; t <= para.num_types; ++t) {
185-
fitness[deviceCount * n + m + (7 * t + 3) * para.population_size] =
193+
fitness_energy[deviceCount * n + m + t * para.population_size] =
186194
para.lambda_e * rmse_energy_array[t];
187-
fitness[deviceCount * n + m + (7 * t + 4) * para.population_size] =
195+
fitness_force[deviceCount * n + m + t * para.population_size] =
188196
para.lambda_f * rmse_force_array[t];
189-
fitness[deviceCount * n + m + (7 * t + 5) * para.population_size] =
197+
fitness_virial[deviceCount * n + m + t * para.population_size] =
190198
para.lambda_v * rmse_virial_array[t];
191-
if (para.charge_mode) {
192-
fitness[deviceCount * n + m + (7 * t + 6) * para.population_size] =
193-
para.lambda_q * rmse_charge_array[t];
194-
} else {
195-
fitness[deviceCount * n + m + (7 * t + 6) * para.population_size] = 0.0f;
196-
}
199+
fitness_charge[deviceCount * n + m + t * para.population_size] =
200+
para.lambda_q * rmse_charge_array[t];
201+
fitness_bec[deviceCount * n + m + t * para.population_size] =
202+
para.lambda_z * rmse_bec_array[t];
197203
}
198204
}
199205
}
@@ -216,33 +222,38 @@ void Fitness::compute(
216222
auto rmse_force_array = train_set[batch_id][m].get_rmse_force(para, true, m);
217223
auto rmse_virial_array = train_set[batch_id][m].get_rmse_virial(para, true, m);
218224
auto rmse_charge_array = train_set[batch_id][m].get_rmse_charge(para, m);
225+
auto rmse_bec_array = train_set[batch_id][m].get_rmse_bec(para, m);
219226
for (int t = 0; t <= para.num_types; ++t) {
220227
// energy
221-
float old_value = fitness[deviceCount * n + m + (7 * t + 3) * para.population_size];
228+
float old_value = fitness_energy[deviceCount * n + m + t * para.population_size];
222229
float new_value = para.lambda_e * rmse_energy_array[t];
223230
new_value = old_value * old_value * count_batch + new_value * new_value;
224231
new_value = sqrt(new_value / (count_batch + 1));
225-
fitness[deviceCount * n + m + (7 * t + 3) * para.population_size] = new_value;
232+
fitness_energy[deviceCount * n + m + t * para.population_size] = new_value;
226233
// force
227-
old_value = fitness[deviceCount * n + m + (7 * t + 4) * para.population_size];
234+
old_value = fitness_force[deviceCount * n + m + t * para.population_size];
228235
new_value = para.lambda_f * rmse_force_array[t];
229236
new_value = old_value * old_value * count_batch + new_value * new_value;
230237
new_value = sqrt(new_value / (count_batch + 1));
231-
fitness[deviceCount * n + m + (7 * t + 4) * para.population_size] = new_value;
238+
fitness_force[deviceCount * n + m + t * para.population_size] = new_value;
232239
// virial
233-
old_value = fitness[deviceCount * n + m + (7 * t + 5) * para.population_size];
240+
old_value = fitness_virial[deviceCount * n + m + t * para.population_size];
234241
new_value = para.lambda_v * rmse_virial_array[t];
235242
new_value = old_value * old_value * count_batch + new_value * new_value;
236243
new_value = sqrt(new_value / (count_batch + 1));
237-
fitness[deviceCount * n + m + (7 * t + 5) * para.population_size] = new_value;
244+
fitness_virial[deviceCount * n + m + t * para.population_size] = new_value;
238245
// charge
239-
if (para.charge_mode) {
240-
old_value = fitness[deviceCount * n + m + (7 * t + 6) * para.population_size];
241-
new_value = para.lambda_q * rmse_charge_array[t];
242-
new_value = old_value * old_value * count_batch + new_value * new_value;
243-
new_value = sqrt(new_value / (count_batch + 1));
244-
fitness[deviceCount * n + m + (7 * t + 6) * para.population_size] = new_value;
245-
}
246+
old_value = fitness_charge[deviceCount * n + m + t * para.population_size];
247+
new_value = para.lambda_q * rmse_charge_array[t];
248+
new_value = old_value * old_value * count_batch + new_value * new_value;
249+
new_value = sqrt(new_value / (count_batch + 1));
250+
fitness_charge[deviceCount * n + m + t * para.population_size] = new_value;
251+
// BEC
252+
old_value = fitness_bec[deviceCount * n + m + t * para.population_size];
253+
new_value = para.lambda_z * rmse_bec_array[t];
254+
new_value = old_value * old_value * count_batch + new_value * new_value;
255+
new_value = sqrt(new_value / (count_batch + 1));
256+
fitness_bec[deviceCount * n + m + t * para.population_size] = new_value;
246257
}
247258
}
248259
}
@@ -435,10 +446,14 @@ void Fitness::report_error(
435446
train_set[batch_id][0].get_rmse_energy(para, energy_shift_per_structure, false, true, 0);
436447
auto rmse_force_train_array = train_set[batch_id][0].get_rmse_force(para, false, 0);
437448
auto rmse_virial_train_array = train_set[batch_id][0].get_rmse_virial(para, false, 0);
449+
auto rmse_charge_train_array = train_set[batch_id][0].get_rmse_charge(para, 0);
450+
auto rmse_bec_train_array = train_set[batch_id][0].get_rmse_bec(para, 0);
438451

439452
float rmse_energy_train = rmse_energy_train_array.back();
440453
float rmse_force_train = rmse_force_train_array.back();
441454
float rmse_virial_train = rmse_virial_train_array.back();
455+
float rmse_charge_train = rmse_charge_train_array.back();
456+
float rmse_bec_train = rmse_bec_train_array.back();
442457

443458
// correct the last bias parameter in the NN
444459
if (para.train_mode == 0 || para.train_mode == 3) {
@@ -448,16 +463,22 @@ void Fitness::report_error(
448463
float rmse_energy_test = 0.0f;
449464
float rmse_force_test = 0.0f;
450465
float rmse_virial_test = 0.0f;
466+
float rmse_charge_test = 0.0f;
467+
float rmse_bec_test = 0.0f;
451468
if (has_test_set) {
452469
potential->find_force(para, elite, test_set, false, true, 1);
453470
float energy_shift_per_structure_not_used;
454471
auto rmse_energy_test_array =
455472
test_set[0].get_rmse_energy(para, energy_shift_per_structure_not_used, false, false, 0);
456473
auto rmse_force_test_array = test_set[0].get_rmse_force(para, false, 0);
457474
auto rmse_virial_test_array = test_set[0].get_rmse_virial(para, false, 0);
475+
auto rmse_charge_test_array = test_set[0].get_rmse_charge(para, 0);
476+
auto rmse_bec_test_array = test_set[0].get_rmse_bec(para, 0);
458477
rmse_energy_test = rmse_energy_test_array.back();
459478
rmse_force_test = rmse_force_test_array.back();
460479
rmse_virial_test = rmse_virial_test_array.back();
480+
rmse_charge_test = rmse_charge_test_array.back();
481+
rmse_bec_test = rmse_bec_test_array.back();
461482
}
462483

463484
FILE* fid_nep = my_fopen("nep.txt", "w");
@@ -475,32 +496,71 @@ void Fitness::report_error(
475496
}
476497

477498
if (para.train_mode == 0 || para.train_mode == 3) {
478-
printf(
479-
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
480-
generation + 1,
481-
loss_total,
482-
loss_L1,
483-
loss_L2,
484-
rmse_energy_train,
485-
rmse_force_train,
486-
rmse_virial_train,
487-
rmse_energy_test,
488-
rmse_force_test,
489-
rmse_virial_test);
490-
fprintf(
491-
fid_loss_out,
492-
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
493-
generation + 1,
494-
loss_total,
495-
loss_L1,
496-
loss_L2,
497-
rmse_energy_train,
498-
rmse_force_train,
499-
rmse_virial_train,
500-
rmse_energy_test,
501-
rmse_force_test,
502-
rmse_virial_test);
499+
if (!para.charge_mode) {
500+
// NEP models
501+
printf(
502+
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
503+
generation + 1,
504+
loss_total,
505+
loss_L1,
506+
loss_L2,
507+
rmse_energy_train,
508+
rmse_force_train,
509+
rmse_virial_train,
510+
rmse_energy_test,
511+
rmse_force_test,
512+
rmse_virial_test);
513+
fprintf(
514+
fid_loss_out,
515+
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
516+
generation + 1,
517+
loss_total,
518+
loss_L1,
519+
loss_L2,
520+
rmse_energy_train,
521+
rmse_force_train,
522+
rmse_virial_train,
523+
rmse_energy_test,
524+
rmse_force_test,
525+
rmse_virial_test);
526+
} else {
527+
// qNEP models:
528+
printf(
529+
"%-8d%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f\n",
530+
generation + 1,
531+
loss_total,
532+
loss_L1,
533+
loss_L2,
534+
rmse_energy_train,
535+
rmse_force_train,
536+
rmse_virial_train,
537+
rmse_charge_train,
538+
rmse_bec_train,
539+
rmse_energy_test,
540+
rmse_force_test,
541+
rmse_virial_test,
542+
rmse_charge_test,
543+
rmse_bec_test);
544+
fprintf(
545+
fid_loss_out,
546+
"%-8d%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f\n",
547+
generation + 1,
548+
loss_total,
549+
loss_L1,
550+
loss_L2,
551+
rmse_energy_train,
552+
rmse_force_train,
553+
rmse_virial_train,
554+
rmse_charge_train,
555+
rmse_bec_train,
556+
rmse_energy_test,
557+
rmse_force_test,
558+
rmse_virial_test,
559+
rmse_charge_test,
560+
rmse_bec_test);
561+
}
503562
} else {
563+
// TNEP models:
504564
printf(
505565
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f\n",
506566
generation + 1,

src/main_nep/fitness.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Fitness
2828
public:
2929
Fitness(Parameters& para);
3030
~Fitness();
31-
void compute(const int generation, Parameters& para, const float*, float*);
31+
void compute(const int generation, Parameters& para, const float*, float*, float*, float*, float*, float*);
3232
void report_error(
3333
Parameters& para,
3434
const int generation,

0 commit comments

Comments
 (0)