Skip to content

Commit f0713e8

Browse files
committed
interactive plotting (bs not multiple)
1 parent c9c8500 commit f0713e8

7 files changed

Lines changed: 111 additions & 102 deletions

File tree

bioencoder/core/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,13 @@ def build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=F
227227
second_stage=True
228228
)
229229

230-
train_features_loader = torch.utils.data.DataLoader(
230+
train_loader = torch.utils.data.DataLoader(
231231
train_features_dataset,
232232
batch_size=batch_sizes['train_batch_size'],
233233
shuffle=True,
234234
num_workers=num_workers,
235235
pin_memory=True,
236-
drop_last=True
236+
drop_last=(batch_sizes['train_batch_size'] is not None)
237237
)
238238

239239
valid_loader = torch.utils.data.DataLoader(
@@ -242,11 +242,11 @@ def build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=F
242242
shuffle=False,
243243
num_workers=num_workers,
244244
pin_memory=True,
245-
drop_last=True
245+
drop_last=(batch_sizes['valid_batch_size'] is not None)
246246
)
247247

248248
loaders = {
249-
'train_features_loader': train_features_loader,
249+
'train_loader': train_loader,
250250
'valid_loader': valid_loader
251251
}
252252

bioencoder/scripts/interactive_plots.py

Lines changed: 64 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -46,94 +46,90 @@ def interactive_plots(
4646
4747
"""
4848

49-
## load bioencoer config
50-
root_dir = config.root_dir
51-
run_name = config.run_name
52-
53-
## load config
49+
## Load Bioencoder config
50+
root_dir, run_name = config.root_dir, config.run_name
5451
hyperparams = utils.load_yaml(config_path)
5552

56-
## parse config
53+
## Parse config
5754
backbone = hyperparams["model"]["backbone"]
5855
num_classes = hyperparams["model"].get("num_classes", None)
5956
checkpoint = hyperparams["model"].get("checkpoint", "swa")
60-
stage = hyperparams["model"].get("stage", "first")
57+
stage = hyperparams.get("model", {}).get("stage", "first")
58+
6159
batch_sizes = {
62-
"train_batch_size": hyperparams["dataloaders"]["train_batch_size"],
63-
"valid_batch_size": hyperparams["dataloaders"]["valid_batch_size"],
60+
"train_batch_size": hyperparams.get("dataloaders", {}).get("train_batch_size"),
61+
"valid_batch_size": hyperparams.get("dataloaders", {}).get("valid_batch_size",1),
6462
}
65-
num_workers = hyperparams["dataloaders"]["num_workers"]
66-
color_classes = hyperparams.get("color_classes", None)
67-
color_map = hyperparams.get("color_map", "jet")
68-
plot_style = hyperparams.get("plot_style", 1)
69-
point_size = hyperparams.get("point_size", 10)
70-
perplexity = hyperparams.get("perplexity", None)
63+
num_workers = hyperparams.get("dataloaders", {}).get("num_workers", 4)
64+
perplexity = hyperparams.get("perplexity", 30)
7165

72-
## set up dirs
73-
data_dir = os.path.join(root_dir,"data", run_name)
74-
plot_dir = os.path.join(root_dir, "plots", run_name)
75-
os.makedirs(plot_dir, exist_ok=True)
66+
plot_config = {
67+
"color_classes": hyperparams.get("color_classes", None),
68+
"color_map": hyperparams.get("color_map", "jet"),
69+
"plot_style": hyperparams.get("plot_style", 1),
70+
"point_size": hyperparams.get("point_size", 10),
71+
}
72+
7673

77-
## plot path
78-
plot_path = os.path.join(plot_dir, f"embeddings_{run_name}.html")
74+
## Set up directories
75+
data_dir = os.path.join(root_dir, "data", run_name)
76+
plot_path = os.path.join(root_dir, "plots", run_name, f"embeddings_{run_name}.html")
7977
if not overwrite and not kwargs.get("ret_embeddings"):
8078
assert not os.path.isfile(plot_path), f"File exists: {plot_path}"
8179

82-
## load weights
80+
## Load model and set up
8381
print(f"Checkpoint: using {checkpoint} of {stage} stage")
84-
ckpt_pretrained = os.path.join(config.root_dir, "weights", run_name, stage, checkpoint)
85-
86-
## set random seed
82+
ckpt_pretrained = os.path.join(root_dir, "weights", run_name, stage, checkpoint)
8783
utils.set_seed()
88-
89-
## extract embeddings
9084
transforms = utils.build_transforms(hyperparams)
91-
loaders = utils.build_loaders(
92-
data_dir, transforms, batch_sizes, num_workers, second_stage=(stage == "second")
93-
)
94-
model = utils.build_model(
95-
backbone,
96-
second_stage=(stage == "second"),
97-
num_classes=num_classes,
98-
ckpt_pretrained=ckpt_pretrained,
99-
).cuda()
85+
loaders = utils.build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=(stage == "second"))
86+
model = utils.build_model(backbone, second_stage=(stage == "second"), num_classes=num_classes, ckpt_pretrained=ckpt_pretrained).cuda()
10087
model.use_projection_head(False)
10188
model.eval()
102-
embeddings_train, labels_train = utils.compute_embeddings(
103-
loaders["valid_loader"], model
104-
)
105-
106-
## load dataset
107-
rel_paths_train = [item[0][len(root_dir) + 1:] for item in loaders["valid_loader"].dataset.imgs]
108-
109-
## return embeddings without plotting
89+
90+
## Determine which embeddings to compute
91+
embeddings, labels, rel_paths = [], [], []
92+
93+
## val batch size cant be zero
94+
embeddings_val, labels_val = utils.compute_embeddings(loaders["valid_loader"], model)
95+
if len(embeddings_val) < len(loaders["valid_loader"].dataset.imgs):
96+
missed_imgs = len(loaders["valid_loader"].dataset.imgs) - len(embeddings_val)
97+
print(f"Warning: missed {missed_imgs} images because batch size was not a multiple of validation dataset size.")
98+
rel_paths_val = [item[0][len(root_dir) + 1:] for item in loaders["valid_loader"].dataset.imgs[:len(embeddings_val)]]
99+
embeddings.extend(embeddings_val)
100+
labels.extend(labels_val)
101+
rel_paths.extend(rel_paths_val)
102+
103+
## train set embeddings
104+
if batch_sizes["train_batch_size"] is not None:
105+
embeddings_train, labels_train = utils.compute_embeddings(loaders["train_loader"], model)
106+
if len(embeddings_train) < len(loaders["train_loader"].dataset.imgs):
107+
missed_imgs = len(loaders["train_loader"].dataset.imgs) - len(embeddings_train)
108+
print(f"Warning: missed {missed_imgs} images because batch size was not a multiple of training dataset size.")
109+
rel_paths_train = [item[0][len(root_dir) + 1:] for item in loaders["train_loader"].dataset.imgs[:len(embeddings_train)]]
110+
embeddings.extend(embeddings_train)
111+
labels.extend(labels_train)
112+
rel_paths.extend(rel_paths_train)
113+
114+
## Return embeddings without plotting
110115
if kwargs.get("ret_embeddings"):
116+
df = pd.DataFrame({"image_name": [os.path.basename(p) for p in rel_paths], "class": [os.path.basename(os.path.dirname(p)) for p in rel_paths]})
117+
return pd.concat([df, pd.DataFrame(embeddings)], axis=1)
111118

112-
df = pd.DataFrame([os.path.basename(item) for item in rel_paths_train], columns=["image_name"])
113-
df["class"] = [
114-
os.path.basename(os.path.dirname(item[0])) for item in loaders["valid_loader"].dataset.imgs
115-
]
116-
return pd.concat([df, pd.DataFrame(embeddings_train)], axis=1)
117-
118-
## reduce dimensionality
119-
perplexity = perplexity if perplexity else min(100, len(embeddings_train) // 2)
120-
reduced_data, colnames, _ = helpers.embbedings_dimension_reductions(
121-
embeddings_train, perplexity
122-
)
123-
df = pd.DataFrame(reduced_data, columns=colnames)
124-
df["paths"] = [ os.path.join("..", "..", item) for item in rel_paths_train]
125-
df["class"] = labels_train
126-
df["class_str"] = [
127-
os.path.basename(os.path.dirname(item[0])) for item in loaders["valid_loader"].dataset.imgs
128-
]
129-
130-
## check if color matches n classes
131-
if color_classes:
132-
assert len(np.unique(labels_train)) == len(color_classes), f"Number of classes is {len(np.unique(labels_train))}, but you only provided {len(color_classes)} colors"
133-
134-
helpers.bokeh_plot(df, out_path=plot_path, color_map=color_map, color_classes=color_classes,
135-
plot_style=plot_style, point_size=point_size)
119+
## Reduce dimensionality
120+
if not perplexity:
121+
perplexity = min(100, len(embeddings) // 2)
122+
print(f"tSNE: using a perplexity value of {perplexity}")
123+
reduced_data, colnames, _ = helpers.embbedings_dimension_reductions(embeddings, perplexity)
136124

125+
## make plot
126+
df = pd.DataFrame(reduced_data, columns=colnames)
127+
df["paths"] = [os.path.join("..", "..", p) for p in rel_paths]
128+
df["class"], df["class_str"] = labels, [os.path.basename(os.path.dirname(p)) for p in rel_paths]
129+
df["dataset"] = df["paths"].apply(lambda x: "validation" if "/val/" in x else "train")
130+
131+
helpers.bokeh_plot(df, out_path=plot_path, **plot_config)
132+
137133

138134
def cli():
139135

bioencoder/scripts/lr_finder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def lr_finder(
118118
optim["scheduler"],
119119
)
120120
lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
121-
lr_finder.range_test(loaders["train_features_loader"], end_lr=1, num_iter=num_iter)
121+
lr_finder.range_test(loaders["train_loader"], end_lr=1, num_iter=num_iter)
122122

123123
fig, ax = plt.subplots()
124124
ax, lr = lr_finder.plot(ax=ax, skip_start=skip_start, skip_end=skip_end)

bioencoder/scripts/split_dataset.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def split_dataset(
1616
image_dir,
1717
mode="flat",
1818
val_percent=0.1,
19-
max_ratio=7,
19+
max_ratio=10,
2020
min_per_class=20,
2121
random_seed=42,
2222
dry_run=False,
@@ -205,11 +205,11 @@ def cli():
205205
parser = argparse.ArgumentParser()
206206
parser.add_argument("--image-dir", type=str, help="Path to the images directory sorted into class-specific subfolders.")
207207
parser.add_argument("--mode", type=str, choices=['flat', 'random', 'fixed'], default='flat', help="Type of dataset split to perform.")
208-
parser.add_argument("--val_percent", type=float, default=0.1, help="Percentage of data to use as validation set.")
209-
parser.add_argument("--max_ratio", type=int, default=7, help="Maximum ratio between the most and least abundant classes.")
210-
parser.add_argument("--min_per_class", type=int, default=20, help="Minimum number of images per class.")
211-
parser.add_argument("--random_seed", type=int, default=42, help="Seed for random number generator.")
212-
parser.add_argument("--dry_run", action='store_true', help="Run without making any changes.")
208+
parser.add_argument("--val-percent", type=float, default=0.1, help="Percentage of data to use as validation set.")
209+
parser.add_argument("--max-ratio", type=int, default=7, help="Maximum ratio between the most and least abundant classes.")
210+
parser.add_argument("--min-per-class", type=int, default=20, help="Minimum number of images per class.")
211+
parser.add_argument("--random-seed", type=int, default=42, help="Seed for random number generator.")
212+
parser.add_argument("--dry-run", action='store_true', help="Run without making any changes.")
213213
parser.add_argument("--overwrite", action='store_true', help="Overwrite existing files without asking.")
214214
args = parser.parse_args()
215215

bioencoder/scripts/swa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def swa(
112112

113113
if stage == "first":
114114
valid_metrics = utils.validation_constructive(
115-
loaders["valid_loader"], loaders["train_features_loader"], model, scaler
115+
loaders["valid_loader"], loaders["train_loader"], model, scaler
116116
)
117117
else:
118118
valid_metrics = utils.validation_ce(

bioencoder/scripts/train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def train(
219219
optim["loss_optimizer"],
220220
)
221221
if ema:
222-
iters = len(loaders["train_features_loader"])
222+
iters = len(loaders["train_loader"])
223223
ema_decay = ema_decay_per_epoch ** (1 / iters)
224224
ema = ExponentialMovingAverage(model.parameters(), decay=ema_decay)
225225

@@ -244,7 +244,7 @@ def train(
244244
)
245245
else:
246246
train_metrics = utils.train_epoch_ce(
247-
loaders["train_features_loader"],
247+
loaders["train_loader"],
248248
model,
249249
criterion,
250250
optimizer,
@@ -261,7 +261,7 @@ def train(
261261

262262
if stage == "first":
263263
valid_metrics_projection_head = utils.validation_constructive(
264-
loaders["valid_loader"], loaders["train_features_loader"], model, scaler
264+
loaders["valid_loader"], loaders["train_loader"], model, scaler
265265
)
266266

267267
## check for GPU parallelization
@@ -270,7 +270,7 @@ def train(
270270
#model_copy.use_projection_head(False)
271271
model.use_projection_head(False)
272272
valid_metrics_encoder = utils.validation_constructive(
273-
loaders["valid_loader"], loaders["train_features_loader"], model, scaler
273+
loaders["valid_loader"], loaders["train_loader"], model, scaler
274274
)
275275
model.use_projection_head(True)
276276
#model_copy.use_projection_head(True)

bioencoder/vis/helpers.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
1+
2+
#%% imports
3+
14
import numpy as np
25
import pandas as pd
36
import cv2
47
import torch
58
from torchvision import transforms
69
from sklearn import decomposition, manifold
710
import matplotlib.pyplot as plt
8-
from bokeh.models import (LassoSelectTool, PanTool,
9-
ResetTool, Div, CustomJS,
10-
HoverTool, WheelZoomTool)
11-
TOOLS = [LassoSelectTool, PanTool, WheelZoomTool, ResetTool]
12-
from bokeh.models import ColumnDataSource
11+
from bokeh.models import (LassoSelectTool, PanTool,ResetTool, Div, CustomJS, HoverTool, WheelZoomTool,
12+
ColumnDataSource, Legend, LegendItem)
1313
from bokeh import plotting as bplot
14+
from bokeh.transform import factor_mark
1415
from bokeh.io import show
1516
from bokeh import layouts
1617

18+
TOOLS = [LassoSelectTool, PanTool, WheelZoomTool, ResetTool]
19+
20+
21+
#%% functions
22+
1723

1824
def preprocess_image(img):
1925
"""
@@ -236,7 +242,7 @@ def embbedings_dimension_reductions(data_table, perplexity):
236242
return np.hstack((pca, tsne)), names, pca_obj
237243

238244

239-
def bokeh_plot(df, out_path='plot.html', color_map="jet", color_classes=None, plot_style=1,
245+
def bokeh_plot(df, out_path='plot.html', color_map="viridis", color_classes=None, plot_style=1,
240246
point_size=10, **kwargs):
241247
"""
242248
Plot a scatter plot of the PCA and t-SNE dimensions of the data using bokeh.
@@ -256,18 +262,25 @@ class labels of the images).
256262

257263
if not all(col in df.columns for col in ['paths', 'class']):
258264
raise ValueError("The dataframe must have columns 'paths' and 'class'")
259-
260-
df['image_files'] = df['paths']
261-
262-
## color management
265+
266+
unique_classes = df['class'].unique()
267+
268+
269+
## Color management
263270
if color_classes:
271+
assert len(unique_classes) == len(color_classes), (
272+
f"Number of classes is {len(unique_classes)}, but only {len(color_classes)} colors provided."
273+
)
274+
275+
# Convert dict to DataFrame and merge colors
264276
df_col = pd.DataFrame.from_dict(color_classes.items())
265-
df_col.columns = ["class_str","color"]
266-
df = df.merge(df_col)
277+
df_col.columns = ["class_str", "color"]
278+
df = df.merge(df_col, how="left", left_on="class", right_on="class_str").drop(columns=["class_str"])
279+
267280
else:
268-
num_classes = len(df['class'].unique())
269-
cmap=plt.cm.get_cmap(color_map, num_classes)
270-
colors_raw = cmap((df['class']), bytes=True)
281+
num_classes = len(unique_classes)
282+
cmap = plt.cm.get_cmap(color_map, num_classes)
283+
colors_raw = cmap(df['class'], bytes=True)
271284
colors_str = ['#%02x%02x%02x' % tuple(c[:3]) for c in colors_raw]
272285
df['color'] = colors_str
273286

@@ -279,7 +292,7 @@ class labels of the images).
279292
<div>
280293
<div>
281294
<img
282-
src="@image_files" height="192" alt="image"
295+
src="@paths" height="192" alt="image"
283296
style="float: left; margin: 0px 15px 15px 0px; image-rendering: pixelated;"
284297
border="2"
285298
></img>
@@ -312,8 +325,8 @@ class labels of the images).
312325
const indices = hit_test_result.indices;
313326
if (indices.length > 0) {
314327
div.text = `<img
315-
src="${ds.data['image_files'][indices[0]]}"
316-
style="float: left; margin: 0px 15px 15px 0px; max-width: 650px; max-height: 650px; width: auto; height: auto;"
328+
src="${ds.data['paths'][indices[0]]}"
329+
style="float: left; margin: 0px 15px 15px 0px; max-width: 650px; max-height: 500px; width: auto; height: auto;"
317330
border="2"
318331
/>`;
319332
}

0 commit comments

Comments
 (0)