Skip to content

Commit a3d4b34

Browse files
authored
Merge pull request brucefan1983#1142 from brucefan1983/dump-nep-restart-periodically
Dump nep restart periodically
2 parents 1d9c56f + fe5cf24 commit a3d4b34

File tree

7 files changed

+40
-18
lines changed

7 files changed

+40
-18
lines changed

doc/nep/input_parameters/save_potential.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ This keyword sets the number of generations between writing a ``nep.txt`` che
99
If ``<format>`` is set to ``0`` the output file name is formatted as ``nep_gen[generation].txt``.
1010
If ``<format>`` is set to ``1`` the output file name is formatted as ``nep_y[year]_m[month]_d[day]_h[hour]_m[minute]_s[second]_generation[generation].txt``.
1111
These model files can be used to monitor the training progress of your model.
12+
Additionally, if ``<save_restart>`` is set to ``1`` the :ref:`nep.restart file <nep_restart>` is also written, following the naming format set by ``<format>``.
1213
Note that the :ref:`nep.restart file <nep_restart>` is the file that is required to continue training.
1314

1415
The syntax is::
1516

16-
save_potential <number_of_generations_between_save_potential> <format>
17+
save_potential <number_of_generations_between_save_potential> <format> <save_restart>
1718

1819
The default number of generations between saved model files is :math:`N=10^5` and the output file names use the extended format.

src/main_nep/fitness.cu

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,19 @@ void Fitness::write_nep_txt(FILE* fid_nep, Parameters& para, float* elite)
415415
}
416416
}
417417

418+
// Build the extension-less base name used for periodically saved files and
// store it in `label`. Callers append the suffix they need (e.g. ".txt").
//
//   para       - only para.save_potential_format is read here.
//   generation - zero-based generation index; the label carries generation + 1.
//   label      - output; overwritten with the generated base name.
//
// When para.save_potential_format == 1 the label is prefixed with the current
// local time:
//   nep_y[year]_m[month]_d[day]_h[hour]_m[minute]_s[second]_generation[N]
// For any other value the short form is used:
//   nep_gen[N]
void Fitness::get_save_potential_label(Parameters& para, const int generation, std::string& label)
{
  // Both naming schemes end with the one-based generation number.
  const std::string generation_suffix = std::to_string(generation + 1);

  // Short format: no timestamp required.
  if (para.save_potential_format != 1) {
    label = "nep_gen" + generation_suffix;
    return;
  }

  // Extended format: stamp the label with the current local wall-clock time.
  time_t now;
  time(&now);
  struct tm* local_now = localtime(&now);
  char stamp[200];
  strftime(stamp, sizeof(stamp), "nep_y%Y_m%m_d%d_h%H_m%M_s%S_generation", local_now);
  label = std::string(stamp) + generation_suffix;
}
430+
418431
void Fitness::report_error(
419432
Parameters& para,
420433
const int generation,
@@ -462,16 +475,8 @@ void Fitness::report_error(
462475

463476
if (0 == (generation + 1) % para.save_potential) {
464477
std::string filename;
465-
if (para.save_potential_format == 1) {
466-
time_t rawtime;
467-
time(&rawtime);
468-
struct tm* timeinfo = localtime(&rawtime);
469-
char buffer[200];
470-
strftime(buffer, sizeof(buffer), "nep_y%Y_m%m_d%d_h%H_m%M_s%S_generation", timeinfo);
471-
filename = std::string(buffer) + std::to_string(generation + 1) + ".txt";
472-
} else {
473-
filename = "nep_gen" + std::to_string(generation + 1) + ".txt";
474-
}
478+
get_save_potential_label(para, generation, filename);
479+
filename += ".txt";
475480

476481
FILE* fid_nep = my_fopen(filename.c_str(), "w");
477482
write_nep_txt(fid_nep, para, elite);

src/main_nep/fitness.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public:
3737
const float loss_L2,
3838
float* elite);
3939
void predict(Parameters& para, float* elite);
40+
void get_save_potential_label(Parameters& para, const int generation, std::string& filename);
4041

4142
protected:
4243
bool has_test_set = false;

src/main_nep/parameters.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,8 +1316,8 @@ void Parameters::parse_save_potential(const char** param, int num_param)
13161316
{
13171317
is_save_potential_set = true;
13181318

1319-
if (num_param != 3) {
1320-
PRINT_INPUT_ERROR("save_potential should have 2 parameters.\n");
1319+
if (num_param != 4) {
1320+
PRINT_INPUT_ERROR("save_potential should have 3 parameters.\n");
13211321
}
13221322
if (!is_valid_int(param[1], &save_potential)) {
13231323
PRINT_INPUT_ERROR("save_potential interval should be an integer.\n");
@@ -1331,4 +1331,10 @@ void Parameters::parse_save_potential(const char** param, int num_param)
13311331
if (save_potential_format != 0 && save_potential_format != 1) {
13321332
PRINT_INPUT_ERROR("save_potential format should be 0 or 1.");
13331333
}
1334+
if (!is_valid_int(param[3], &save_potential_restart)) {
1335+
PRINT_INPUT_ERROR("save_potential save restart should be an integer.\n");
1336+
}
1337+
if (save_potential_restart != 0 && save_potential_restart != 1) {
1338+
PRINT_INPUT_ERROR("save_potential save restart should be 0 or 1.");
1339+
}
13341340
}

src/main_nep/parameters.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ public:
3131
int population_size; // population size for SNES
3232
int maximum_generation; // maximum number of generations for SNES;
3333
int save_potential; // number of generations between writing a checkpoint nep.txt file.
34-
int save_potential_format; // format of checkpoint nep.txt file name
34+
int save_potential_format; // format of checkpoint nep.txt file name
35+
int save_potential_restart; // if restart files should be written or not. 0=no, 1=yes
3536
int num_neurons1; // number of nuerons in the 1st hidden layer (only one hidden layer)
3637
int basis_size_radial; // for nep3
3738
int basis_size_angular; // for nep3

src/main_nep/snes.cu

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,15 @@ void SNES::compute(Parameters& para, Fitness* fitness_function)
356356

357357
update_mu_and_sigma(para);
358358
if (0 == (n + 1) % 100) {
359-
output_mu_and_sigma(para);
359+
const char* filename = "nep.restart";
360+
output_mu_and_sigma(para, filename);
361+
}
362+
// Optionally save the nep.restart file at the same time as save_potential
363+
if (0 == (n + 1) % para.save_potential && para.save_potential_restart) {
364+
std::string restart_file;
365+
fitness_function->get_save_potential_label(para, n, restart_file);
366+
restart_file += ".restart";
367+
output_mu_and_sigma(para, restart_file.c_str());
360368
}
361369
}
362370
} else {
@@ -642,12 +650,12 @@ void SNES::update_mu_and_sigma(Parameters& para)
642650
GPU_CHECK_KERNEL;
643651
}
644652

645-
void SNES::output_mu_and_sigma(Parameters& para)
653+
void SNES::output_mu_and_sigma(Parameters& para, const char* filename)
646654
{
647655
gpuSetDevice(0); // normally use GPU-0
648656
gpu_mu.copy_to_host(mu.data());
649657
gpu_sigma.copy_to_host(sigma.data());
650-
FILE* fid_restart = my_fopen("nep.restart", "w");
658+
FILE* fid_restart = my_fopen(filename, "w");
651659
for (int n = 0; n < number_of_variables; ++n) {
652660
fprintf(fid_restart, "%15.7e %15.7e\n", mu[n], sigma[n]);
653661
}

src/main_nep/snes.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,5 @@ protected:
7070
void regularize_NEP4(Parameters& para);
7171
void sort_population(Parameters& para);
7272
void update_mu_and_sigma(Parameters& para);
73-
void output_mu_and_sigma(Parameters& para);
73+
void output_mu_and_sigma(Parameters& para, const char* filename);
7474
};

0 commit comments

Comments
 (0)