Skip to content

Commit 73f3600

Browse files
committed
report charge and bec rmse
1 parent dc5c93a commit 73f3600

2 files changed

Lines changed: 105 additions & 37 deletions

File tree

src/main_nep/fitness.cu

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,14 @@ void Fitness::report_error(
446446
train_set[batch_id][0].get_rmse_energy(para, energy_shift_per_structure, false, true, 0);
447447
auto rmse_force_train_array = train_set[batch_id][0].get_rmse_force(para, false, 0);
448448
auto rmse_virial_train_array = train_set[batch_id][0].get_rmse_virial(para, false, 0);
449+
auto rmse_charge_train_array = train_set[batch_id][0].get_rmse_charge(para, 0);
450+
auto rmse_bec_train_array = train_set[batch_id][0].get_rmse_bec(para, 0);
449451

450452
float rmse_energy_train = rmse_energy_train_array.back();
451453
float rmse_force_train = rmse_force_train_array.back();
452454
float rmse_virial_train = rmse_virial_train_array.back();
455+
float rmse_charge_train = rmse_charge_train_array.back();
456+
float rmse_bec_train = rmse_bec_train_array.back();
453457

454458
// correct the last bias parameter in the NN
455459
if (para.train_mode == 0 || para.train_mode == 3) {
@@ -459,16 +463,22 @@ void Fitness::report_error(
459463
float rmse_energy_test = 0.0f;
460464
float rmse_force_test = 0.0f;
461465
float rmse_virial_test = 0.0f;
466+
float rmse_charge_test = 0.0f;
467+
float rmse_bec_test = 0.0f;
462468
if (has_test_set) {
463469
potential->find_force(para, elite, test_set, false, true, 1);
464470
float energy_shift_per_structure_not_used;
465471
auto rmse_energy_test_array =
466472
test_set[0].get_rmse_energy(para, energy_shift_per_structure_not_used, false, false, 0);
467473
auto rmse_force_test_array = test_set[0].get_rmse_force(para, false, 0);
468474
auto rmse_virial_test_array = test_set[0].get_rmse_virial(para, false, 0);
475+
auto rmse_charge_test_array = test_set[0].get_rmse_charge(para, 0);
476+
auto rmse_bec_test_array = test_set[0].get_rmse_bec(para, 0);
469477
rmse_energy_test = rmse_energy_test_array.back();
470478
rmse_force_test = rmse_force_test_array.back();
471479
rmse_virial_test = rmse_virial_test_array.back();
480+
rmse_charge_test = rmse_charge_test_array.back();
481+
rmse_bec_test = rmse_bec_test_array.back();
472482
}
473483

474484
FILE* fid_nep = my_fopen("nep.txt", "w");
@@ -486,32 +496,71 @@ void Fitness::report_error(
486496
}
487497

488498
if (para.train_mode == 0 || para.train_mode == 3) {
489-
printf(
490-
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
491-
generation + 1,
492-
loss_total,
493-
loss_L1,
494-
loss_L2,
495-
rmse_energy_train,
496-
rmse_force_train,
497-
rmse_virial_train,
498-
rmse_energy_test,
499-
rmse_force_test,
500-
rmse_virial_test);
501-
fprintf(
502-
fid_loss_out,
503-
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
504-
generation + 1,
505-
loss_total,
506-
loss_L1,
507-
loss_L2,
508-
rmse_energy_train,
509-
rmse_force_train,
510-
rmse_virial_train,
511-
rmse_energy_test,
512-
rmse_force_test,
513-
rmse_virial_test);
499+
if (!para.charge_mode) {
500+
// NEP models
501+
printf(
502+
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
503+
generation + 1,
504+
loss_total,
505+
loss_L1,
506+
loss_L2,
507+
rmse_energy_train,
508+
rmse_force_train,
509+
rmse_virial_train,
510+
rmse_energy_test,
511+
rmse_force_test,
512+
rmse_virial_test);
513+
fprintf(
514+
fid_loss_out,
515+
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
516+
generation + 1,
517+
loss_total,
518+
loss_L1,
519+
loss_L2,
520+
rmse_energy_train,
521+
rmse_force_train,
522+
rmse_virial_train,
523+
rmse_energy_test,
524+
rmse_force_test,
525+
rmse_virial_test);
526+
} else {
527+
// qNEP models:
528+
printf(
529+
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
530+
generation + 1,
531+
loss_total,
532+
loss_L1,
533+
loss_L2,
534+
rmse_energy_train,
535+
rmse_force_train,
536+
rmse_virial_train,
537+
rmse_charge_train,
538+
rmse_bec_train,
539+
rmse_energy_test,
540+
rmse_force_test,
541+
rmse_virial_test,
542+
rmse_charge_test,
543+
rmse_bec_test);
544+
fprintf(
545+
fid_loss_out,
546+
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n",
547+
generation + 1,
548+
loss_total,
549+
loss_L1,
550+
loss_L2,
551+
rmse_energy_train,
552+
rmse_force_train,
553+
rmse_virial_train,
554+
rmse_charge_train,
555+
rmse_bec_train,
556+
rmse_energy_test,
557+
rmse_force_test,
558+
rmse_virial_test,
559+
rmse_charge_test,
560+
rmse_bec_test);
561+
}
514562
} else {
563+
// TNEP models:
515564
printf(
516565
"%-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f\n",
517566
generation + 1,

src/main_nep/snes.cu

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -318,18 +318,37 @@ void SNES::compute(Parameters& para, Fitness* fitness_function)
318318
if (para.prediction == 0) {
319319

320320
if (para.train_mode == 0 || para.train_mode == 3) {
321-
printf(
322-
"%-8s%-11s%-11s%-11s%-13s%-13s%-13s%-13s%-13s%-13s\n",
323-
"Step",
324-
"Total-Loss",
325-
"L1Reg-Loss",
326-
"L2Reg-Loss",
327-
"RMSE-E-Train",
328-
"RMSE-F-Train",
329-
"RMSE-V-Train",
330-
"RMSE-E-Test",
331-
"RMSE-F-Test",
332-
"RMSE-V-Test");
321+
if (!para.charge_mode) {
322+
printf(
323+
"%-8s%-11s%-11s%-11s%-13s%-13s%-13s%-13s%-13s%-13s\n",
324+
"Step",
325+
"Total-Loss",
326+
"L1Reg-Loss",
327+
"L2Reg-Loss",
328+
"RMSE-E-Train",
329+
"RMSE-F-Train",
330+
"RMSE-V-Train",
331+
"RMSE-E-Test",
332+
"RMSE-F-Test",
333+
"RMSE-V-Test");
334+
} else {
335+
printf(
336+
"%-8s%-11s%-11s%-11s%-13s%-13s%-13s%-13s%-13s%-13s%-13s%-13s%-13s%-13s\n",
337+
"Step",
338+
"Total-Loss",
339+
"L1Reg-Loss",
340+
"L2Reg-Loss",
341+
"RMSE-E-Train",
342+
"RMSE-F-Train",
343+
"RMSE-V-Train",
344+
"RMSE-Q-Train",
345+
"RMSE-Z-Train",
346+
"RMSE-E-Test",
347+
"RMSE-F-Test",
348+
"RMSE-V-Test",
349+
"RMSE-Q-Test",
350+
"RMSE-Z-Test");
351+
}
333352
} else {
334353
printf(
335354
"%-8s%-11s%-11s%-11s%-13s%-13s\n",

0 commit comments

Comments
 (0)