@@ -75,6 +75,7 @@ def train(
7575 ema_decay_per_epoch = hyperparams ["train" ]["ema_decay_per_epoch" ]
7676 n_epochs = hyperparams ["train" ]["n_epochs" ]
7777 target_metric = hyperparams ["train" ]["target_metric" ]
78+ min_improvement = hyperparams ["train" ].get ("min_improvement" , 0.01 )
7879 stage = hyperparams ["train" ]["stage" ]
7980 optimizer_params = hyperparams ["optimizer" ]
8081 scheduler_params = hyperparams ["scheduler" ]
@@ -319,15 +320,14 @@ def train(
319320 pass
320321
321322 # check if the best value of metric changed. If so -> save the model
322- if (
323- valid_metrics [target_metric ] > metric_best * 0.99
324- ): # > 0 if wanting to save all models
323+ current_metric = valid_metrics [target_metric ]
324+ if metric_best == 0 or current_metric > metric_best * (1 + min_improvement ):
325325 logger .info (
326- "{} increased ({:.6f} --> {:.6f}). Saving model ..." .format (
327- target_metric , metric_best , valid_metrics [ target_metric ]
326+ "{} improved by ≥{:.2%} ({:.6f} --> {:.6f}). Saving model ..." .format (
327+ target_metric , min_improvement , metric_best , current_metric
328328 )
329329 )
330-
330+
331331 torch .save (
332332 {
333333 "epoch" : epoch ,
@@ -336,8 +336,10 @@ def train(
336336 },
337337 os .path .join (weights_dir , f"epoch{ epoch } " ),
338338 )
339- metric_best = valid_metrics [target_metric ]
340-
339+ metric_best = current_metric
340+ else :
341+ logger .info (f"Metric { target_metric } did not improve by ≥{ min_improvement :.2%} (best: { metric_best :.6f} , current: { current_metric :.6f} )" )
342+
341343 # if ema is used, go back to regular weights without ema
342344 if ema :
343345 utils .copy_parameters_to_model (copy_of_model_parameters , model )
0 commit comments