|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import os |
| 16 | +import gc |
16 | 17 | import time |
17 | 18 | import yaml |
18 | 19 | import json |
@@ -367,13 +368,15 @@ def train(model, |
367 | 368 | os.path.join(current_save_dir, 'model.pdopt')) |
368 | 369 | if uniform_output_enabled: |
369 | 370 | export(cli_args, model, current_save_dir) |
| 371 | + gc.collect() |
370 | 372 |
|
371 | 373 | if use_ema: |
372 | 374 | paddle.save( |
373 | 375 | ema_model.state_dict(), |
374 | 376 | os.path.join(current_save_dir, 'ema_model.pdparams')) |
375 | 377 | if uniform_output_enabled: |
376 | 378 | export(cli_args, ema_model, current_save_dir, use_ema) |
| 379 | + gc.collect() |
377 | 380 |
|
378 | 381 | save_models.append(current_save_dir) |
379 | 382 | if len(save_models) > keep_checkpoint_max > 0: |
@@ -405,6 +408,7 @@ def train(model, |
405 | 408 | os.path.join(best_model_dir, 'model.pdstates')) |
406 | 409 | if uniform_output_enabled: |
407 | 410 | export(cli_args, model, best_model_dir) |
| 411 | + gc.collect() |
408 | 412 | save_model_info(states_dict, best_model_dir) |
409 | 413 | update_train_results(cli_args, |
410 | 414 | "best_model", |
@@ -450,6 +454,7 @@ def train(model, |
450 | 454 | if uniform_output_enabled: |
451 | 455 | export(cli_args, ema_model, best_ema_model_dir, |
452 | 456 | use_ema) |
| 457 | + gc.collect() |
453 | 458 | save_model_info(ema_states_dict, |
454 | 459 | best_ema_model_dir) |
455 | 460 | update_train_results(cli_args, |
|
0 commit comments