Skip to content

Commit 21971ea

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/dev' into dev
2 parents bc8af3f + 3f572ea commit 21971ea

9 files changed

Lines changed: 92 additions & 71 deletions

File tree

bioencoder/core/augmentations.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import albumentations as A
33
from albumentations import pytorch as AT
44

5-
def get_transforms(config, valid=False):
5+
def get_transforms(config, no_aug=False):
66
"""
77
Return a transformation pipeline based on the provided configuration.
88
@@ -13,14 +13,16 @@ def get_transforms(config, valid=False):
1313
Returns:
1414
albumentations.core.composition.Compose: The image transformation pipeline.
1515
"""
16-
default_size = 224
17-
img_size = config.get('img_size', default_size)
16+
17+
img_size = config.get('img_size')
18+
if img_size is None:
19+
raise ValueError("config must include 'img_size'")
1820
config_aug = config.get('augmentations', {})
1921
aug = get_aug_from_config(config_aug.get('transforms', []))
2022

2123
return A.Compose([
2224
A.Resize(img_size, img_size, always_apply=True),
23-
A.NoOp() if valid else aug,
25+
A.NoOp() if no_aug else aug,
2426
A.Normalize(),
2527
AT.ToTensorV2()
2628
])

bioencoder/core/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,17 @@ def build_transforms(config):
188188
189189
"""
190190
train_transforms = get_transforms(config)
191-
valid_transforms = get_transforms(config, valid=True)
191+
valid_transforms = get_transforms(config, no_aug=True)
192192

193193
return {
194194
"train_transforms": train_transforms,
195195
"valid_transforms": valid_transforms
196196
}
197197

198198

199-
def build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=False, is_supcon=False):
199+
def build_loaders(data_dir, transforms, batch_sizes, num_workers,
200+
second_stage=False, is_supcon=False,
201+
shuffle_train=True, drop_last=True):
200202
"""
201203
Build data loaders for training and validation.
202204
@@ -230,19 +232,19 @@ def build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=F
230232
train_loader = torch.utils.data.DataLoader(
231233
train_features_dataset,
232234
batch_size=batch_sizes['train_batch_size'],
233-
shuffle=True,
235+
shuffle=shuffle_train,
234236
num_workers=num_workers,
235237
pin_memory=True,
236-
drop_last=(batch_sizes['train_batch_size'] is not None)
238+
drop_last=drop_last and batch_sizes['train_batch_size'] is not None
237239
)
238-
240+
239241
valid_loader = torch.utils.data.DataLoader(
240242
valid_dataset,
241243
batch_size=batch_sizes['valid_batch_size'],
242244
shuffle=False,
243245
num_workers=num_workers,
244246
pin_memory=True,
245-
drop_last=(batch_sizes['valid_batch_size'] is not None)
247+
drop_last=drop_last
246248
)
247249

248250
loaders = {

bioencoder/scripts/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def inference(
8585
utils.set_seed()
8686

8787
## get transformations
88-
transform = utils.get_transforms(hyperparams, valid=False)
88+
transform = utils.get_transforms(hyperparams, no_aug=True)
8989

9090
## build model
9191
if config.model_path != ckpt_pretrained:

bioencoder/scripts/interactive_plots.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def interactive_plots(
6161
"valid_batch_size": hyperparams.get("dataloaders", {}).get("valid_batch_size",1),
6262
}
6363
num_workers = hyperparams.get("dataloaders", {}).get("num_workers", 4)
64-
perplexity = hyperparams.get("perplexity", 30)
64+
perplexity = hyperparams.get("perplexity")
6565

6666
plot_config = {
6767
"color_classes": hyperparams.get("color_classes", None),
@@ -70,43 +70,40 @@ def interactive_plots(
7070
"point_size": hyperparams.get("point_size", 10),
7171
}
7272

73-
74-
## Set up directories
73+
## directories and file management
7574
data_dir = os.path.join(root_dir, "data", run_name)
76-
plot_path = os.path.join(root_dir, "plots", run_name, f"embeddings_{run_name}.html")
75+
plot_dir = os.path.join(root_dir, "plots", run_name)
76+
os.makedirs(plot_dir, exist_ok=True)
77+
plot_path = os.path.join(plot_dir, "embeddings_interactive_plot.html")
7778
if not overwrite and not kwargs.get("ret_embeddings"):
78-
assert not os.path.isfile(plot_path), f"File exists: {plot_path}"
79+
assert not os.path.isfile(plot_path), f"File already exists: {plot_path}"
7980

8081
## Load model and set up
8182
print(f"Checkpoint: using {checkpoint} of {stage} stage")
8283
ckpt_pretrained = os.path.join(root_dir, "weights", run_name, stage, checkpoint)
8384
utils.set_seed()
84-
transforms = utils.build_transforms(hyperparams)
85-
loaders = utils.build_loaders(data_dir, transforms, batch_sizes, num_workers, second_stage=(stage == "second"))
8685
model = utils.build_model(backbone, second_stage=(stage == "second"), num_classes=num_classes, ckpt_pretrained=ckpt_pretrained).cuda()
8786
model.use_projection_head(False)
8887
model.eval()
8988

90-
## Determine which embeddings to compute
89+
## prep computation
90+
transforms = utils.build_transforms(hyperparams)
91+
loaders = utils.build_loaders(
92+
data_dir, transforms, batch_sizes, num_workers,
93+
second_stage=(stage == "second"), drop_last=False, shuffle_train=False)
9194
embeddings, labels, rel_paths = [], [], []
9295

93-
## val batch size cant be zero
96+
## val set - batch size cant be zero
9497
embeddings_val, labels_val = utils.compute_embeddings(loaders["valid_loader"], model)
95-
if len(embeddings_val) < len(loaders["valid_loader"].dataset.imgs):
96-
missed_imgs = len(loaders["valid_loader"].dataset.imgs) - len(embeddings_val)
97-
print(f"Warning: missed {missed_imgs} images because batch size was not a multiple of validation dataset size.")
98-
rel_paths_val = [item[0][len(root_dir) + 1:] for item in loaders["valid_loader"].dataset.imgs[:len(embeddings_val)]]
98+
rel_paths_val = [item[0][len(root_dir) + 1:] for item in loaders["valid_loader"].dataset.imgs]
9999
embeddings.extend(embeddings_val)
100100
labels.extend(labels_val)
101101
rel_paths.extend(rel_paths_val)
102102

103-
## train set embeddings
103+
## train set - skipped if zero batch size
104104
if batch_sizes["train_batch_size"] is not None:
105105
embeddings_train, labels_train = utils.compute_embeddings(loaders["train_loader"], model)
106-
if len(embeddings_train) < len(loaders["train_loader"].dataset.imgs):
107-
missed_imgs = len(loaders["train_loader"].dataset.imgs) - len(embeddings_train)
108-
print(f"Warning: missed {missed_imgs} images because batch size was not a multiple of training dataset size.")
109-
rel_paths_train = [item[0][len(root_dir) + 1:] for item in loaders["train_loader"].dataset.imgs[:len(embeddings_train)]]
106+
rel_paths_train = [item[0][len(root_dir) + 1:] for item in loaders["train_loader"].dataset.imgs]
110107
embeddings.extend(embeddings_train)
111108
labels.extend(labels_train)
112109
rel_paths.extend(rel_paths_train)
@@ -118,7 +115,7 @@ def interactive_plots(
118115

119116
## Reduce dimensionality
120117
if not perplexity:
121-
perplexity = min(100, len(embeddings) // 2)
118+
perplexity = min(30, max(5, (len(embeddings) - 1) / 3))
122119
print(f"tSNE: using a perplexity value of {perplexity}")
123120
reduced_data, colnames, _ = helpers.embbedings_dimension_reductions(embeddings, perplexity)
124121

@@ -127,7 +124,6 @@ def interactive_plots(
127124
df["paths"] = [os.path.join("..", "..", p) for p in rel_paths]
128125
df["class"], df["class_str"] = labels, [os.path.basename(os.path.dirname(p)) for p in rel_paths]
129126
df["dataset"] = df["paths"].apply(lambda x: "validation" if "/val/" in x else "train")
130-
131127
helpers.bokeh_plot(df, out_path=plot_path, **plot_config)
132128

133129

bioencoder/scripts/model_explorer.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def model_explorer(
7373
backbone = hyperparams["model"]["backbone"]
7474
num_classes = hyperparams["model"].get("num_classes", None)
7575
stage = hyperparams["model"]["stage"]
76-
76+
img_size = hyperparams.get("img_size", None)
77+
if img_size is None:
78+
raise ValueError("config must include 'img_size'")
79+
7780
## get swa path
7881
ckpt_pretrained = os.path.join(root_dir, "weights", run_name, stage, "swa")
7982

@@ -94,7 +97,7 @@ def model_explorer(
9497
uploaded_file = st.sidebar.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
9598

9699
## get image transformations
97-
transform = utils.get_transforms(hyperparams, valid=False)
100+
transform = utils.get_transforms(hyperparams, no_aug=True)
98101

99102
# Load the model and add to cache
100103
model = load_model(
@@ -105,9 +108,13 @@ def model_explorer(
105108
)
106109

107110
if uploaded_file is not None:
111+
108112
# Display the uploaded image
109113
image = Image.open(uploaded_file).convert('RGB')
110114
st.sidebar.image(image, caption="Input Image", use_column_width=True)
115+
116+
# resize image
117+
image_resized = image.resize((img_size, img_size))
111118

112119
# Generate visualizations
113120
selected = option_menu(None, vis_funcs, icons=['list' for _ in range(len(vis_funcs))], menu_icon="cast", orientation="horizontal")
@@ -123,27 +130,31 @@ def model_explorer(
123130
layer = st.selectbox("Select a layer", layers.keys())
124131
module = layers[layer]
125132
max_acts = st.slider("Max activations", 5, 64, 25)
126-
result = vis.visualize_activations(model, module, image, max_acts=max_acts)
133+
result = vis.visualize_activations(model, module, image_resized, max_acts=max_acts)
127134
st.pyplot(result)
128135

129136
elif selected == 'Saliency':
130-
result = vis.saliency_map(model, image)
137+
result = vis.saliency_map(model, image_resized)
131138
st.pyplot(result)
132139

133140
elif selected == 'GradCAM':
134141
# add activation type (Relu, Silu, etc_)
135-
layers =[name.split('.')[0] for name, module in model.encoder.named_modules() if isinstance(module, (torch.nn.SiLU, torch.nn.ReLU))]
142+
layers =[name.split('.')[0] for name, module in model.encoder.named_modules() \
143+
if isinstance(module, (torch.nn.SiLU, torch.nn.ReLU))]
136144
layer_set = sorted(set(layers))
137145
layer = st.selectbox("Select a layer", list(layer_set), index=len(list(layer_set))-1)
138-
result = vis.grad_cam(model, model.encoder,image,target_layer=[layer], target_category= None)
146+
result = vis.grad_cam(model, model.encoder,image_resized,target_layer=[layer], target_category= None)
139147
st.pyplot(result)
140148

141149
elif selected == 'ConstrativeCAM':
142-
layers =[name.split('.')[0] for name, module in model.encoder.named_modules() if isinstance(module, (torch.nn.SiLU, torch.nn.ReLU))]
150+
layers =[name.split('.')[0] for name, module in model.encoder.named_modules() \
151+
if isinstance(module, (torch.nn.SiLU, torch.nn.ReLU))]
143152
layer_set = sorted(set(layers))
144153
layer = st.selectbox("Select a layer", list(layer_set), index=len(list(layer_set))-1)
145154
target = st.selectbox("Select a target", class_names)
146-
result = vis.contrast_cam(model, model.encoder,image,target_layer=[layer], target_category=class_names.index(target))
155+
result = vis.contrast_cam(
156+
model, model.encoder, image_resized,target_layer=[layer],
157+
target_category=class_names.index(target))
147158
st.pyplot(result)
148159

149160

bioencoder/scripts/model_explorer_wrapper.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@ def model_explorer_wrapper(config_path):
1515
process = ["streamlit", "run", script_path , "--", "--config-path", config_path]
1616
subprocess.run(process)
1717

18-
19-
20-
2118
def cli():
2219

2320
parser = argparse.ArgumentParser()

bioencoder/vis/helpers.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def embbedings_dimension_reductions(data_table, perplexity):
242242
return np.hstack((pca, tsne)), names, pca_obj
243243

244244

245-
def bokeh_plot(df, out_path='plot.html', color_map="viridis", color_classes=None, plot_style=1,
245+
def bokeh_plot(df, out_path='plot.html', color_map="jet1", color_classes=None, plot_style=1,
246246
point_size=10, **kwargs):
247247
"""
248248
Plot a scatter plot of the PCA and t-SNE dimensions of the data using bokeh.
@@ -264,8 +264,9 @@ class labels of the images).
264264
raise ValueError("The dataframe must have columns 'paths' and 'class'")
265265

266266
unique_classes = df['class'].unique()
267-
268-
267+
unique_datasets = df['dataset'].unique()
268+
markers = ['circle', 'square'] # Define markers for each group
269+
269270
## Color management
270271
if color_classes:
271272
assert len(unique_classes) == len(color_classes), (
@@ -284,10 +285,12 @@ class labels of the images).
284285
colors_str = ['#%02x%02x%02x' % tuple(c[:3]) for c in colors_raw]
285286
df['color'] = colors_str
286287

288+
287289
source = ColumnDataSource(df)
288290
bplot.output_file(out_path)
289291

290292
if plot_style == 1:
293+
div = Div(text="")
291294
tooltip = """
292295
<div>
293296
<div>
@@ -306,18 +309,12 @@ class labels of the images).
306309
hover1 = HoverTool(tooltips=tooltip)
307310
tools0 = [t() for t in TOOLS] + [hover0]
308311
tools1 = [t() for t in TOOLS] + [hover1]
309-
pca = bplot.figure(tools=tools0)
310-
pca.scatter('PC1', 'PC2', color='color', source=source, size=point_size)
311-
tsne = bplot.figure(tools=tools1)
312-
tsne.scatter('tSNE-0', 'tSNE-1', color='color', source=source, size=point_size)
313-
p = bplot.gridplot([[pca, tsne]])
314-
bplot.show(p)
315-
312+
316313
elif plot_style == 2:
317314
div = Div(text="")
318315
hover=HoverTool(
319316
tooltips = [
320-
("class_str", "@class_str"),
317+
("Class", "@class_str"),
321318
]
322319
)
323320
hover.callback = CustomJS(args=dict(div=div, ds=source), code="""
@@ -333,11 +330,28 @@ class labels of the images).
333330
""")
334331
tools0 = [t() for t in TOOLS] + [hover]
335332
tools1 = [t() for t in TOOLS] + [hover]
336-
pca = bplot.figure(tools=tools0)
337-
pca.scatter('PC1', 'PC2', color='color', source=source, size=point_size)
338-
tsne = bplot.figure(tools=tools1)
339-
tsne.scatter('tSNE-0', 'tSNE-1', color='color', source=source, size=point_size)
340-
p = bplot.gridplot([[pca, tsne]])
341-
show(layouts.row(p, div))
333+
334+
# Create figures
335+
pca = bplot.figure(tools=tools0, title="PCA", match_aspect=True)
336+
tsne = bplot.figure(tools=tools1, title="t-SNE", match_aspect=True)
337+
338+
# Store renderers for dataset legend
339+
legend_items_dataset = []
340+
341+
# Scatter plots with different markers for datasets
342+
for dataset, marker in zip(unique_datasets, markers):
343+
dataset_source = ColumnDataSource(df[df['dataset'].astype(str) == dataset]) # Filter dataset-specific data
344+
r = pca.scatter('PC1', 'PC2', source=dataset_source, color='color', size=point_size, marker=marker)
345+
tsne.scatter('tSNE-0', 'tSNE-1', source=dataset_source, color='color', size=point_size, marker=marker)
346+
legend_items_dataset.append(LegendItem(label=str(dataset), renderers=[r]))
347+
348+
# Create and add horizontal legend for dataset markers
349+
legend_dataset = Legend(items=legend_items_dataset, orientation="horizontal")
350+
pca.add_layout(legend_dataset, 'below')
351+
352+
# Display plots
353+
p = bplot.gridplot([[pca, tsne]])
354+
show(layouts.row(p, div))
355+
342356

343357
return p

bioencoder/vis/methods.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def saliency_map(model, img, device = 'cuda', save_path = None):
110110
plt.savefig(save_path)
111111
return fig
112112

113-
def grad_cam(model, module, img, target_layer = ["4"], target_category= None, device = 'cuda', save_path = None):
113+
def grad_cam(model, module, img, target_layer = ["4"],
114+
target_category= None, device = 'cuda', save_path = None):
114115

115116
for param in model.parameters():
116117
param.requires_grad = True

bioencoder_configs/plot_stage1.yml

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,19 @@ model:
44
stage: first # Training stage: 'first' for initial training, 'second' for fine-tuning
55

66
dataloaders:
7-
train_batch_size: 2 # Batch size for training data; ensure validation set size is a multiple of this number
8-
valid_batch_size: 2 # Batch size for validation data
9-
num_workers: 32 # Number of CPU threads for data loading; should not exceed the number of CPU cores
7+
train_batch_size: 20 # Larger is faster; no value or removing this line will not include training data
8+
valid_batch_size: 20 # Larger is faster; val data is always plotted
9+
num_workers: 32 # Should not exceed available CPU cores
1010

11-
img_size: 384 # Image size for training and validation
11+
img_size: 384 # image size used for training
1212

13-
plot_style: 1 # (1: pictogram above point, 2: pictogram next to plot panel)
14-
15-
color_classes: # overrides color_map
16-
#class1: "#FFD467"
17-
#class2: "#4DC9F2"
13+
perplexity: 30 # for tSNE<; cannot be larger than dataset
1814

15+
plot_style: 2 # (1: pictogram above point, 2: pictogram next to plot panel)
1916
point_size: 10 ## size of points in scatter plot
20-
2117
color_map: 'Set1' # Default color map; see https://matplotlib.org/stable/users/explain/colors/colormaps.html for options
22-
18+
#color_classes: # overrides color_map
19+
#class1: "#FFD467"
20+
#class2: "#4DC9F2"
2321

2422

0 commit comments

Comments
 (0)