Skip to content

Commit b7fc02c

Browse files
authored
Merge pull request brucefan1983#1168 from brucefan1983/no_bec_output
only output bec when there are target values
2 parents eef3872 + cefc083 commit b7fc02c

4 files changed

Lines changed: 35 additions & 23 deletions

File tree

src/main_nep/dataset.cu

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -992,21 +992,23 @@ std::vector<float> Dataset::get_rmse_charge(Parameters& para, int device_id)
992992
}
993993
}
994994

995-
gpu_sum_bec_error<<<Nc, block_size, sizeof(float) * block_size>>>(
996-
N,
997-
Na.data(),
998-
Na_sum.data(),
999-
bec.data(),
1000-
bec_ref_gpu.data(),
1001-
error_gpu.data());
1002-
CHECK(gpuMemcpy(error_cpu.data(), error_gpu.data(), mem, gpuMemcpyDeviceToHost));
1003-
for (int n = 0; n < Nc; ++n) {
1004-
if (structures[n].has_bec) {
1005-
float rmse_temp = error_cpu[n];
1006-
for (int t = 0; t < para.num_types + 1; ++t) {
1007-
if (has_type[t * Nc + n]) {
1008-
rmse_array[t] += rmse_temp / (Na_cpu[n]);
1009-
count_array[t] += 9;
995+
if (para.has_bec) {
996+
gpu_sum_bec_error<<<Nc, block_size, sizeof(float) * block_size>>>(
997+
N,
998+
Na.data(),
999+
Na_sum.data(),
1000+
bec.data(),
1001+
bec_ref_gpu.data(),
1002+
error_gpu.data());
1003+
CHECK(gpuMemcpy(error_cpu.data(), error_gpu.data(), mem, gpuMemcpyDeviceToHost));
1004+
for (int n = 0; n < Nc; ++n) {
1005+
if (structures[n].has_bec) {
1006+
float rmse_temp = error_cpu[n];
1007+
for (int t = 0; t < para.num_types + 1; ++t) {
1008+
if (has_type[t * Nc + n]) {
1009+
rmse_array[t] += rmse_temp / (Na_cpu[n]);
1010+
count_array[t] += 9;
1011+
}
10101012
}
10111013
}
10121014
}

src/main_nep/fitness.cu

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -546,9 +546,11 @@ void Fitness::report_error(
546546
FILE* fid_charge = my_fopen("charge_test.out", "w");
547547
update_charge(fid_charge, test_set[0]);
548548
fclose(fid_charge);
549-
FILE* fid_bec = my_fopen("bec_test.out", "w");
550-
update_bec(fid_bec, test_set[0]);
551-
fclose(fid_bec);
549+
if (para.has_bec) {
550+
FILE* fid_bec = my_fopen("bec_test.out", "w");
551+
update_bec(fid_bec, test_set[0]);
552+
fclose(fid_bec);
553+
}
552554
}
553555
} else if (para.train_mode == 1) {
554556
FILE* fid_dipole = my_fopen("dipole_test.out", "w");
@@ -643,15 +645,19 @@ void Fitness::predict(Parameters& para, float* elite)
643645
FILE* fid_bec;
644646
if (para.charge_mode) {
645647
fid_charge = my_fopen("charge_train.out", "w");
646-
fid_bec = my_fopen("bec_train.out", "w");
648+
if (para.has_bec) {
649+
fid_bec = my_fopen("bec_train.out", "w");
650+
}
647651
}
648652
for (int batch_id = 0; batch_id < num_batches; ++batch_id) {
649653
potential->find_force(para, elite, train_set[batch_id], false, true, 1);
650654
update_energy_force_virial(
651655
fid_energy, fid_force, fid_virial, fid_stress, train_set[batch_id][0]);
652656
if (para.charge_mode) {
653657
update_charge(fid_charge, train_set[batch_id][0]);
654-
update_bec(fid_bec, train_set[batch_id][0]);
658+
if (para.has_bec) {
659+
update_bec(fid_bec, train_set[batch_id][0]);
660+
}
655661
}
656662
}
657663
fclose(fid_energy);
@@ -660,7 +666,9 @@ void Fitness::predict(Parameters& para, float* elite)
660666
fclose(fid_stress);
661667
if (para.charge_mode) {
662668
fclose(fid_charge);
663-
fclose(fid_bec);
669+
if (para.has_bec) {
670+
fclose(fid_bec);
671+
}
664672
}
665673
} else if (para.train_mode == 1) {
666674
FILE* fid_dipole = my_fopen("dipole_train.out", "w");

src/main_nep/parameters.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ public:
6767
float typewise_cutoff_zbl_factor;
6868
int output_descriptor;
6969
int charge_mode; // add dynamic charge to NEP potential model
70+
bool has_bec = false; // check if there are target BEC values
7071
int fine_tune = 0; // fine_tune option; 0=no, 1=yes
7172
std::string fine_tune_nep_txt = "";
7273
std::string fine_tune_nep_restart = "";

src/main_nep/structure.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ static void read_force(
183183
}
184184

185185
static void read_one_structure(
186-
const Parameters& para,
186+
Parameters& para,
187187
std::ifstream& input,
188188
Structure& structure,
189189
std::string& xyz_filename,
@@ -457,6 +457,7 @@ static void read_one_structure(
457457
if (sub_tokens[k * 3] == "bec") {
458458
bec_position = k;
459459
structure.has_bec = true;
460+
para.has_bec = true;
460461
}
461462
}
462463
if (species_position < 0) {
@@ -516,7 +517,7 @@ static void read_one_structure(
516517
}
517518

518519
static void read_exyz(
519-
const Parameters& para,
520+
Parameters& para,
520521
std::ifstream& input,
521522
std::vector<Structure>& structures,
522523
std::string& xyz_filename)

0 commit comments

Comments
 (0)