Skip to content

Commit 9c04ac3

Browse files
committed
better val help/fix model improve check
1 parent a189795 commit 9c04ac3

2 files changed

Lines changed: 14 additions & 12 deletions

File tree

bioencoder/scripts/split_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ def split_dataset(
3333
image_dir : str
3434
Path to the directory containing subfolders of images, where each subfolder represents a class.
3535
mode : str, optional
36-
Specifies the strategy for splitting the dataset:
37-
- "flat": Calculating split to the most abundant class (after applying max_ratio), and then applying it to all classes
38-
- "random": Randomly selects images across all classes to form the validation set, disregarding class balance.
39-
- "fixed": Ensures each class contributes a fixed proportion to the validation set, based on `val_percent`.
36+
Strategy for populating the validation subset:
37+
- "flat": Derives a single validation quota from the largest class (after capping by `max_ratio`) and applies it uniformly to all classes.
38+
- "random": Builds the validation set by drawing images uniformly at random from the pooled, balanced dataset, ignoring class membership.
39+
- "fixed": Assigns each class its own `val_percent` share of the validation set, based on that class's balanced size.
4040
Default is "flat".
4141
val_percent : float, optional
4242
Proportion of the dataset to allocate to the validation set, expressed as a decimal.

bioencoder/scripts/train.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def train(
7575
ema_decay_per_epoch = hyperparams["train"]["ema_decay_per_epoch"]
7676
n_epochs = hyperparams["train"]["n_epochs"]
7777
target_metric = hyperparams["train"]["target_metric"]
78+
min_improvement = hyperparams["train"].get("min_improvement", 0.01)
7879
stage = hyperparams["train"]["stage"]
7980
optimizer_params = hyperparams["optimizer"]
8081
scheduler_params = hyperparams["scheduler"]
@@ -319,15 +320,14 @@ def train(
319320
pass
320321

321322
# check if the best value of metric changed. If so -> save the model
322-
if (
323-
valid_metrics[target_metric] > metric_best*0.99
324-
): # > 0 if wanting to save all models
323+
current_metric = valid_metrics[target_metric]
324+
if metric_best == 0 or current_metric > metric_best * (1 + min_improvement):
325325
logger.info(
326-
"{} increased ({:.6f} --> {:.6f}). Saving model ...".format(
327-
target_metric, metric_best, valid_metrics[target_metric]
326+
"{} improved by ≥{:.2%} ({:.6f} --> {:.6f}). Saving model ...".format(
327+
target_metric, min_improvement, metric_best, current_metric
328328
)
329329
)
330-
330+
331331
torch.save(
332332
{
333333
"epoch": epoch,
@@ -336,8 +336,10 @@ def train(
336336
},
337337
os.path.join(weights_dir, f"epoch{epoch}"),
338338
)
339-
metric_best = valid_metrics[target_metric]
340-
339+
metric_best = current_metric
340+
else:
341+
logger.info(f"Metric {target_metric} did not improve by ≥{min_improvement:.2%} (best: {metric_best:.6f}, current: {current_metric:.6f})")
342+
341343
# if ema is used, go back to regular weights without ema
342344
if ema:
343345
utils.copy_parameters_to_model(copy_of_model_parameters, model)

0 commit comments

Comments
 (0)