Skip to content

Commit 0d1e0ce

Browse files
committed
fixed interactive plots (img size), drop last = F
1 parent 62ee214 commit 0d1e0ce

5 files changed

Lines changed: 58 additions & 41 deletions

File tree

bioencoder/core/augmentations.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ def get_transforms(config, valid=False):
1313
Returns:
1414
albumentations.core.composition.Compose: The image transformation pipeline.
1515
"""
16-
default_size = 224
17-
img_size = config.get('img_size', default_size)
16+
17+
img_size = config.get('img_size')
18+
if img_size is None:
19+
raise ValueError("config must include 'img_size'")
1820
config_aug = config.get('augmentations', {})
1921
aug = get_aug_from_config(config_aug.get('transforms', []))
2022

bioencoder/core/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,9 @@ def build_transforms(config):
196196
}
197197

198198

199-
def build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=False, is_supcon=False):
199+
def build_loaders(data_dir, transforms, batch_sizes, num_workers,
200+
second_stage=False, is_supcon=False,
201+
shuffle_train=True, drop_last=True):
200202
"""
201203
Build data loaders for training and validation.
202204
@@ -230,19 +232,19 @@ def build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=F
230232
train_loader = torch.utils.data.DataLoader(
231233
train_features_dataset,
232234
batch_size=batch_sizes['train_batch_size'],
233-
shuffle=True,
235+
shuffle=shuffle_train,
234236
num_workers=num_workers,
235237
pin_memory=True,
236-
drop_last=(batch_sizes['train_batch_size'] is not None)
238+
drop_last=drop_last and batch_sizes['train_batch_size'] is not None
237239
)
238-
240+
239241
valid_loader = torch.utils.data.DataLoader(
240242
valid_dataset,
241243
batch_size=batch_sizes['valid_batch_size'],
242244
shuffle=False,
243245
num_workers=num_workers,
244246
pin_memory=True,
245-
drop_last=(batch_sizes['valid_batch_size'] is not None)
247+
drop_last=drop_last
246248
)
247249

248250
loaders = {

bioencoder/scripts/interactive_plots.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def interactive_plots(
7070
"point_size": hyperparams.get("point_size", 10),
7171
}
7272

73-
7473
## directories and file management
7574
data_dir = os.path.join(root_dir, "data", run_name)
7675
plot_dir = os.path.join(root_dir, "plots", run_name)
@@ -83,32 +82,28 @@ def interactive_plots(
8382
print(f"Checkpoint: using {checkpoint} of {stage} stage")
8483
ckpt_pretrained = os.path.join(root_dir, "weights", run_name, stage, checkpoint)
8584
utils.set_seed()
86-
transforms = utils.build_transforms(hyperparams)
87-
loaders = utils.build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=(stage == "second"))
8885
model = utils.build_model(backbone, second_stage=(stage == "second"), num_classes=num_classes, ckpt_pretrained=ckpt_pretrained).cuda()
8986
model.use_projection_head(False)
9087
model.eval()
9188

92-
## Determine which embeddings to compute
89+
## prep computation
90+
transforms = utils.build_transforms(hyperparams)
91+
loaders = utils.build_loaders(
92+
data_dir, transforms, batch_sizes, num_workers,
93+
second_stage=(stage == "second"), drop_last=False, shuffle_train=False)
9394
embeddings, labels, rel_paths = [], [], []
9495

95-
## val batch size cant be zero
96+
## val set - batch size cant be zero
9697
embeddings_val, labels_val = utils.compute_embeddings(loaders["valid_loader"], model)
97-
if len(embeddings_val) < len(loaders["valid_loader"].dataset.imgs):
98-
missed_imgs = len(loaders["valid_loader"].dataset.imgs) - len(embeddings_val)
99-
print(f"Warning: missed {missed_imgs} images because batch size was not a multiple of validation dataset size.")
100-
rel_paths_val = [item[0][len(root_dir) + 1:] for item in loaders["valid_loader"].dataset.imgs[:len(embeddings_val)]]
98+
rel_paths_val = [item[0][len(root_dir) + 1:] for item in loaders["valid_loader"].dataset.imgs]
10199
embeddings.extend(embeddings_val)
102100
labels.extend(labels_val)
103101
rel_paths.extend(rel_paths_val)
104102

105-
## train set embeddings
103+
## train set - skipped if zero batch size
106104
if batch_sizes["train_batch_size"] is not None:
107105
embeddings_train, labels_train = utils.compute_embeddings(loaders["train_loader"], model)
108-
if len(embeddings_train) < len(loaders["train_loader"].dataset.imgs):
109-
missed_imgs = len(loaders["train_loader"].dataset.imgs) - len(embeddings_train)
110-
print(f"Warning: missed {missed_imgs} images because batch size was not a multiple of training dataset size.")
111-
rel_paths_train = [item[0][len(root_dir) + 1:] for item in loaders["train_loader"].dataset.imgs[:len(embeddings_train)]]
106+
rel_paths_train = [item[0][len(root_dir) + 1:] for item in loaders["train_loader"].dataset.imgs]
112107
embeddings.extend(embeddings_train)
113108
labels.extend(labels_train)
114109
rel_paths.extend(rel_paths_train)
@@ -120,7 +115,7 @@ def interactive_plots(
120115

121116
## Reduce dimensionality
122117
if not perplexity:
123-
perplexity = min(100, len(embeddings) // 2)
118+
perplexity = min(30, max(5, (len(embeddings) - 1) / 3))
124119
print(f"tSNE: using a perplexity value of {perplexity}")
125120
reduced_data, colnames, _ = helpers.embbedings_dimension_reductions(embeddings, perplexity)
126121

@@ -129,7 +124,6 @@ def interactive_plots(
129124
df["paths"] = [os.path.join("..", "..", p) for p in rel_paths]
130125
df["class"], df["class_str"] = labels, [os.path.basename(os.path.dirname(p)) for p in rel_paths]
131126
df["dataset"] = df["paths"].apply(lambda x: "validation" if "/val/" in x else "train")
132-
133127
helpers.bokeh_plot(df, out_path=plot_path, **plot_config)
134128

135129

bioencoder/vis/helpers.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,9 @@ class labels of the images).
264264
raise ValueError("The dataframe must have columns 'paths' and 'class'")
265265

266266
unique_classes = df['class'].unique()
267-
267+
unique_datasets = df['dataset'].unique()
268+
markers = ['circle', 'square'] # Define markers for each group
269+
268270
## Color management
269271
if color_classes:
270272
assert len(unique_classes) == len(color_classes), (
@@ -283,10 +285,12 @@ class labels of the images).
283285
colors_str = ['#%02x%02x%02x' % tuple(c[:3]) for c in colors_raw]
284286
df['color'] = colors_str
285287

288+
286289
source = ColumnDataSource(df)
287290
bplot.output_file(out_path)
288291

289292
if plot_style == 1:
293+
div = Div(text="")
290294
tooltip = """
291295
<div>
292296
<div>
@@ -305,18 +309,12 @@ class labels of the images).
305309
hover1 = HoverTool(tooltips=tooltip)
306310
tools0 = [t() for t in TOOLS] + [hover0]
307311
tools1 = [t() for t in TOOLS] + [hover1]
308-
pca = bplot.figure(tools=tools0)
309-
pca.scatter('PC1', 'PC2', color='color', source=source, size=point_size)
310-
tsne = bplot.figure(tools=tools1)
311-
tsne.scatter('tSNE-0', 'tSNE-1', color='color', source=source, size=point_size)
312-
p = bplot.gridplot([[pca, tsne]])
313-
bplot.show(p)
314-
312+
315313
elif plot_style == 2:
316314
div = Div(text="")
317315
hover=HoverTool(
318316
tooltips = [
319-
("class_str", "@class_str"),
317+
("Class", "@class_str"),
320318
]
321319
)
322320
hover.callback = CustomJS(args=dict(div=div, ds=source), code="""
@@ -332,11 +330,28 @@ class labels of the images).
332330
""")
333331
tools0 = [t() for t in TOOLS] + [hover]
334332
tools1 = [t() for t in TOOLS] + [hover]
335-
pca = bplot.figure(tools=tools0)
336-
pca.scatter('PC1', 'PC2', color='color', source=source, size=point_size)
337-
tsne = bplot.figure(tools=tools1)
338-
tsne.scatter('tSNE-0', 'tSNE-1', color='color', source=source, size=point_size)
339-
p = bplot.gridplot([[pca, tsne]])
340-
show(layouts.row(p, div))
333+
334+
# Create figures
335+
pca = bplot.figure(tools=tools0, title="PCA", match_aspect=True)
336+
tsne = bplot.figure(tools=tools1, title="t-SNE", match_aspect=True)
337+
338+
# Store renderers for dataset legend
339+
legend_items_dataset = []
340+
341+
# Scatter plots with different markers for datasets
342+
for dataset, marker in zip(unique_datasets, markers):
343+
dataset_source = ColumnDataSource(df[df['dataset'].astype(str) == dataset]) # Filter dataset-specific data
344+
r = pca.scatter('PC1', 'PC2', source=dataset_source, color='color', size=point_size, marker=marker)
345+
tsne.scatter('tSNE-0', 'tSNE-1', source=dataset_source, color='color', size=point_size, marker=marker)
346+
legend_items_dataset.append(LegendItem(label=str(dataset), renderers=[r]))
347+
348+
# Create and add horizontal legend for dataset markers
349+
legend_dataset = Legend(items=legend_items_dataset, orientation="horizontal")
350+
pca.add_layout(legend_dataset, 'below')
351+
352+
# Display plots
353+
p = bplot.gridplot([[pca, tsne]])
354+
show(layouts.row(p, div))
355+
341356

342357
return p

bioencoder_configs/plot_stage1.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@ model:
44
stage: first # Training stage: 'first' for initial training, 'second' for fine-tuning
55

66
dataloaders:
7-
train_batch_size: # Larger is faster but may drop leftover data points - no value or removing this line will not include training data
8-
valid_batch_size: 10 # Larger is faster but may drop leftover data points - ideally use a multiple of val set size
7+
train_batch_size: 20 # Larger is faster; no value or removing this line will not include training data
8+
valid_batch_size: 20 # Larger is faster; val data is always plotted
99
num_workers: 32 # Should not exceed available CPU cores
1010

11-
plot_style: 1 # (1: pictogram above point, 2: pictogram next to plot panel)
11+
img_size: 384 # image size used for training
12+
13+
perplexity: 30 # for tSNE; cannot be larger than the dataset size
14+
15+
plot_style: 2 # (1: pictogram above point, 2: pictogram next to plot panel)
1216
point_size: 10 ## size of points in scatter plot
1317
color_map: 'Set1' # Default color map; see https://matplotlib.org/stable/users/explain/colors/colormaps.html for options
1418
#color_classes: # overrides color_map

0 commit comments

Comments
 (0)