@@ -149,7 +149,14 @@ Fitness::~Fitness()
149149}
150150
151151void Fitness::compute (
152- const int generation, Parameters& para, const float * population, float * fitness)
152+ const int generation,
153+ Parameters& para,
154+ const float * population,
155+ float * fitness_energy,
156+ float * fitness_force,
157+ float * fitness_virial,
158+ float * fitness_charge,
159+ float * fitness_bec)
153160{
154161 int deviceCount;
155162 CHECK (gpuGetDeviceCount (&deviceCount));
@@ -180,20 +187,19 @@ void Fitness::compute(
180187 auto rmse_force_array = train_set[batch_id][m].get_rmse_force (para, true , m);
181188 auto rmse_virial_array = train_set[batch_id][m].get_rmse_virial (para, true , m);
182189 auto rmse_charge_array = train_set[batch_id][m].get_rmse_charge (para, m);
190+ auto rmse_bec_array = train_set[batch_id][m].get_rmse_bec (para, m);
183191
184192 for (int t = 0 ; t <= para.num_types ; ++t) {
185- fitness [deviceCount * n + m + ( 7 * t + 3 ) * para.population_size ] =
193+ fitness_energy [deviceCount * n + m + t * para.population_size ] =
186194 para.lambda_e * rmse_energy_array[t];
187- fitness [deviceCount * n + m + ( 7 * t + 4 ) * para.population_size ] =
195+ fitness_force [deviceCount * n + m + t * para.population_size ] =
188196 para.lambda_f * rmse_force_array[t];
189- fitness [deviceCount * n + m + ( 7 * t + 5 ) * para.population_size ] =
197+ fitness_virial [deviceCount * n + m + t * para.population_size ] =
190198 para.lambda_v * rmse_virial_array[t];
191- if (para.charge_mode ) {
192- fitness[deviceCount * n + m + (7 * t + 6 ) * para.population_size ] =
193- para.lambda_q * rmse_charge_array[t];
194- } else {
195- fitness[deviceCount * n + m + (7 * t + 6 ) * para.population_size ] = 0 .0f ;
196- }
199+ fitness_charge[deviceCount * n + m + t * para.population_size ] =
200+ para.lambda_q * rmse_charge_array[t];
201+ fitness_bec[deviceCount * n + m + t * para.population_size ] =
202+ para.lambda_z * rmse_bec_array[t];
197203 }
198204 }
199205 }
@@ -216,33 +222,38 @@ void Fitness::compute(
216222 auto rmse_force_array = train_set[batch_id][m].get_rmse_force (para, true , m);
217223 auto rmse_virial_array = train_set[batch_id][m].get_rmse_virial (para, true , m);
218224 auto rmse_charge_array = train_set[batch_id][m].get_rmse_charge (para, m);
225+ auto rmse_bec_array = train_set[batch_id][m].get_rmse_bec (para, m);
219226 for (int t = 0 ; t <= para.num_types ; ++t) {
220227 // energy
221- float old_value = fitness [deviceCount * n + m + ( 7 * t + 3 ) * para.population_size ];
228+ float old_value = fitness_energy [deviceCount * n + m + t * para.population_size ];
222229 float new_value = para.lambda_e * rmse_energy_array[t];
223230 new_value = old_value * old_value * count_batch + new_value * new_value;
224231 new_value = sqrt (new_value / (count_batch + 1 ));
225- fitness [deviceCount * n + m + ( 7 * t + 3 ) * para.population_size ] = new_value;
232+ fitness_energy [deviceCount * n + m + t * para.population_size ] = new_value;
226233 // force
227- old_value = fitness [deviceCount * n + m + ( 7 * t + 4 ) * para.population_size ];
234+ old_value = fitness_force [deviceCount * n + m + t * para.population_size ];
228235 new_value = para.lambda_f * rmse_force_array[t];
229236 new_value = old_value * old_value * count_batch + new_value * new_value;
230237 new_value = sqrt (new_value / (count_batch + 1 ));
231- fitness [deviceCount * n + m + ( 7 * t + 4 ) * para.population_size ] = new_value;
238+ fitness_force [deviceCount * n + m + t * para.population_size ] = new_value;
232239 // virial
233- old_value = fitness [deviceCount * n + m + ( 7 * t + 5 ) * para.population_size ];
240+ old_value = fitness_virial [deviceCount * n + m + t * para.population_size ];
234241 new_value = para.lambda_v * rmse_virial_array[t];
235242 new_value = old_value * old_value * count_batch + new_value * new_value;
236243 new_value = sqrt (new_value / (count_batch + 1 ));
237- fitness [deviceCount * n + m + ( 7 * t + 5 ) * para.population_size ] = new_value;
244+ fitness_virial [deviceCount * n + m + t * para.population_size ] = new_value;
238245 // charge
239- if (para.charge_mode ) {
240- old_value = fitness[deviceCount * n + m + (7 * t + 6 ) * para.population_size ];
241- new_value = para.lambda_q * rmse_charge_array[t];
242- new_value = old_value * old_value * count_batch + new_value * new_value;
243- new_value = sqrt (new_value / (count_batch + 1 ));
244- fitness[deviceCount * n + m + (7 * t + 6 ) * para.population_size ] = new_value;
245- }
246+ old_value = fitness_charge[deviceCount * n + m + t * para.population_size ];
247+ new_value = para.lambda_q * rmse_charge_array[t];
248+ new_value = old_value * old_value * count_batch + new_value * new_value;
249+ new_value = sqrt (new_value / (count_batch + 1 ));
250+ fitness_charge[deviceCount * n + m + t * para.population_size ] = new_value;
251+ // BEC
252+ old_value = fitness_bec[deviceCount * n + m + t * para.population_size ];
253+ new_value = para.lambda_z * rmse_bec_array[t];
254+ new_value = old_value * old_value * count_batch + new_value * new_value;
255+ new_value = sqrt (new_value / (count_batch + 1 ));
256+ fitness_bec[deviceCount * n + m + t * para.population_size ] = new_value;
246257 }
247258 }
248259 }
@@ -435,10 +446,14 @@ void Fitness::report_error(
435446 train_set[batch_id][0 ].get_rmse_energy (para, energy_shift_per_structure, false , true , 0 );
436447 auto rmse_force_train_array = train_set[batch_id][0 ].get_rmse_force (para, false , 0 );
437448 auto rmse_virial_train_array = train_set[batch_id][0 ].get_rmse_virial (para, false , 0 );
449+ auto rmse_charge_train_array = train_set[batch_id][0 ].get_rmse_charge (para, 0 );
450+ auto rmse_bec_train_array = train_set[batch_id][0 ].get_rmse_bec (para, 0 );
438451
439452 float rmse_energy_train = rmse_energy_train_array.back ();
440453 float rmse_force_train = rmse_force_train_array.back ();
441454 float rmse_virial_train = rmse_virial_train_array.back ();
455+ float rmse_charge_train = rmse_charge_train_array.back ();
456+ float rmse_bec_train = rmse_bec_train_array.back ();
442457
443458 // correct the last bias parameter in the NN
444459 if (para.train_mode == 0 || para.train_mode == 3 ) {
@@ -448,16 +463,22 @@ void Fitness::report_error(
448463 float rmse_energy_test = 0 .0f ;
449464 float rmse_force_test = 0 .0f ;
450465 float rmse_virial_test = 0 .0f ;
466+ float rmse_charge_test = 0 .0f ;
467+ float rmse_bec_test = 0 .0f ;
451468 if (has_test_set) {
452469 potential->find_force (para, elite, test_set, false , true , 1 );
453470 float energy_shift_per_structure_not_used;
454471 auto rmse_energy_test_array =
455472 test_set[0 ].get_rmse_energy (para, energy_shift_per_structure_not_used, false , false , 0 );
456473 auto rmse_force_test_array = test_set[0 ].get_rmse_force (para, false , 0 );
457474 auto rmse_virial_test_array = test_set[0 ].get_rmse_virial (para, false , 0 );
475+ auto rmse_charge_test_array = test_set[0 ].get_rmse_charge (para, 0 );
476+ auto rmse_bec_test_array = test_set[0 ].get_rmse_bec (para, 0 );
458477 rmse_energy_test = rmse_energy_test_array.back ();
459478 rmse_force_test = rmse_force_test_array.back ();
460479 rmse_virial_test = rmse_virial_test_array.back ();
480+ rmse_charge_test = rmse_charge_test_array.back ();
481+ rmse_bec_test = rmse_bec_test_array.back ();
461482 }
462483
463484 FILE* fid_nep = my_fopen (" nep.txt" , " w" );
@@ -475,32 +496,71 @@ void Fitness::report_error(
475496 }
476497
477498 if (para.train_mode == 0 || para.train_mode == 3 ) {
478- printf (
479- " %-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n " ,
480- generation + 1 ,
481- loss_total,
482- loss_L1,
483- loss_L2,
484- rmse_energy_train,
485- rmse_force_train,
486- rmse_virial_train,
487- rmse_energy_test,
488- rmse_force_test,
489- rmse_virial_test);
490- fprintf (
491- fid_loss_out,
492- " %-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n " ,
493- generation + 1 ,
494- loss_total,
495- loss_L1,
496- loss_L2,
497- rmse_energy_train,
498- rmse_force_train,
499- rmse_virial_train,
500- rmse_energy_test,
501- rmse_force_test,
502- rmse_virial_test);
499+ if (!para.charge_mode ) {
500+ // NEP models
501+ printf (
502+ " %-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n " ,
503+ generation + 1 ,
504+ loss_total,
505+ loss_L1,
506+ loss_L2,
507+ rmse_energy_train,
508+ rmse_force_train,
509+ rmse_virial_train,
510+ rmse_energy_test,
511+ rmse_force_test,
512+ rmse_virial_test);
513+ fprintf (
514+ fid_loss_out,
515+ " %-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f%-13.5f\n " ,
516+ generation + 1 ,
517+ loss_total,
518+ loss_L1,
519+ loss_L2,
520+ rmse_energy_train,
521+ rmse_force_train,
522+ rmse_virial_train,
523+ rmse_energy_test,
524+ rmse_force_test,
525+ rmse_virial_test);
526+ } else {
527+ // qNEP models:
528+ printf (
529+ " %-8d%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f\n " ,
530+ generation + 1 ,
531+ loss_total,
532+ loss_L1,
533+ loss_L2,
534+ rmse_energy_train,
535+ rmse_force_train,
536+ rmse_virial_train,
537+ rmse_charge_train,
538+ rmse_bec_train,
539+ rmse_energy_test,
540+ rmse_force_test,
541+ rmse_virial_test,
542+ rmse_charge_test,
543+ rmse_bec_test);
544+ fprintf (
545+ fid_loss_out,
546+ " %-8d%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f%-9.5f\n " ,
547+ generation + 1 ,
548+ loss_total,
549+ loss_L1,
550+ loss_L2,
551+ rmse_energy_train,
552+ rmse_force_train,
553+ rmse_virial_train,
554+ rmse_charge_train,
555+ rmse_bec_train,
556+ rmse_energy_test,
557+ rmse_force_test,
558+ rmse_virial_test,
559+ rmse_charge_test,
560+ rmse_bec_test);
561+ }
503562 } else {
563+ // TNEP models:
504564 printf (
505565 " %-8d%-11.5f%-11.5f%-11.5f%-13.5f%-13.5f\n " ,
506566 generation + 1 ,
0 commit comments