Skip to content

Commit 3e29db4

Browse files
committed
aug / fiddle + save
1 parent b78750d commit 3e29db4

3 files changed

Lines changed: 39 additions & 3 deletions

File tree

bioencoder/core/augmentations.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def get_transforms(config, valid=False):
1515
"""
1616
default_size = 224
1717
img_size = config.get('img_size', default_size)
18-
aug = get_aug_from_config(config.get('augmentations', {}).get('transforms', []))
18+
config_aug = config.get('augmentations', {})
19+
aug = get_aug_from_config(config_aug.get('transforms', []))
1920

2021
return A.Compose([
2122
A.Resize(img_size, img_size, always_apply=True),
@@ -64,3 +65,36 @@ def get_aug_from_config(config):
6465
return getattr(A, name)(*args, **config)
6566

6667

68+
# def get_aug_from_config2(config_aug):
69+
# """
70+
# A helper function to create image augmentation pipeline based on a given config_auguration.
71+
72+
# Parameters:
73+
# config_aug (str, list, or dict): A string, list of strings, or dictionary representing the augmentation pipeline.
74+
75+
# Returns:
76+
# aug (albumentations.augmentations.transforms): The constructed augmentation pipeline.
77+
# """
78+
# config_aug = copy.deepcopy(config_aug)
79+
80+
# if config_aug is None:
81+
# return A.NoOp()
82+
83+
# elif isinstance(config_aug, dict):
84+
85+
# compose = config_aug.get("compose", {
86+
# "name": "Sequential",
87+
# "p1": 1})
88+
89+
# transforms_inst = getattr(A, compose["name"])
90+
# transforms_inst(parse_transforms(config_aug.get("transforms", {})), p=compose["p1"])
91+
92+
93+
# def parse_transforms(transforms):
94+
# transforms_list = []
95+
# for trans in transforms:
96+
# name = list(trans.keys())[0]
97+
# params = trans.get(name, {})
98+
# transforms_list.append(getattr(A, name)(**params))
99+
100+
# return transforms_list

bioencoder/core/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import importlib
33
import random
44
import os
5+
import shutil
56
import numpy as np
67
import yaml
78
from collections import defaultdict
@@ -643,6 +644,7 @@ def save_augmented_sample(data_dir, transform, n_samples, seed):
643644
# Load dataset
644645
dataset = ImageFolder(root=os.path.join(data_dir, "train"))
645646
save_dir = os.path.join(data_dir, "aug_sample")
647+
shutil.rmtree(save_dir)
646648
os.makedirs(save_dir, exist_ok=True)
647649

648650
## reverse image net transforms

bioencoder/scripts/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def train(
8484
"valid_batch_size": hyperparams["dataloaders"]["valid_batch_size"],
8585
}
8686
num_workers = hyperparams["dataloaders"]["num_workers"]
87-
aug_sample = hyperparams["augmentations"].get("save_sample", False)
87+
aug_sample = hyperparams["augmentations"].get("sample_save", False)
8888
aug_sample_n = hyperparams["augmentations"].get("sample_n", 5)
89-
aug_sample_seed = hyperparams["augmentations"].get("seed", 42)
89+
aug_sample_seed = hyperparams["augmentations"].get("sample_seed", 42)
9090

9191
## manage directories and paths
9292
data_dir = os.path.join(root_dir, "data", run_name)

0 commit comments

Comments
 (0)