Skip to content

Commit 632879b

Browse files
committed
fix interactive plotting (for real this time)
1 parent fa54d46 commit 632879b

3 files changed

Lines changed: 80 additions & 47 deletions

File tree

bioencoder/scripts/interactive_plots.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,21 @@ def interactive_plots(
6969
"plot_style": hyperparams.get("plot_style", 1),
7070
"point_size": hyperparams.get("point_size", 10),
7171
}
72-
72+
73+
return_results = hyperparams.get("return_results", False)
74+
7375
## directories and file management
7476
data_dir = os.path.join(root_dir, "data", run_name)
7577
plot_dir = os.path.join(root_dir, "plots", run_name)
7678
os.makedirs(plot_dir, exist_ok=True)
7779
plot_path = os.path.join(plot_dir, "embeddings_interactive_plot.html")
78-
if not overwrite and not kwargs.get("ret_embeddings"):
80+
if not overwrite and not (return_embeddings or return_coords):
7981
assert not os.path.isfile(plot_path), f"File already exists: {plot_path}"
8082

8183
## Load model and set up
8284
print(f"Checkpoint: using {checkpoint} of {stage} stage")
8385
ckpt_pretrained = os.path.join(root_dir, "weights", run_name, stage, checkpoint)
84-
utils.set_seed()
86+
seed = utils.set_seed()
8587
model = utils.build_model(backbone, second_stage=(stage == "second"), num_classes=num_classes, ckpt_pretrained=ckpt_pretrained).cuda()
8688
model.use_projection_head(False)
8789
model.eval()
@@ -91,40 +93,57 @@ def interactive_plots(
9193
loaders = utils.build_loaders(
9294
data_dir, transforms, batch_sizes, num_workers,
9395
second_stage=(stage == "second"), drop_last=False, shuffle_train=False)
94-
embeddings, labels, rel_paths = [], [], []
9596

96-
## val set - batch size cant be zero
97+
## val set (always computed)
9798
embeddings_val, labels_val = utils.compute_embeddings(loaders["valid_loader"], model)
9899
rel_paths_val = [item[0][len(root_dir) + 1:] for item in loaders["valid_loader"].dataset.imgs]
99-
embeddings.extend(embeddings_val)
100-
labels.extend(labels_val)
101-
rel_paths.extend(rel_paths_val)
100+
# Build validation DataFrame (meta + embeddings)
101+
df_val_meta = pd.DataFrame({
102+
"image_name": [os.path.basename(p) for p in rel_paths_val],
103+
"class_str": [os.path.basename(os.path.dirname(p)) for p in rel_paths_val],
104+
"dataset": "val",
105+
})
106+
df_embeddings = pd.concat([df_val_meta, pd.DataFrame(embeddings_val)], axis=1)
102107

103108
## train set - skipped if train_batch_size is None
104109
if batch_sizes["train_batch_size"] is not None:
105110
embeddings_train, labels_train = utils.compute_embeddings(loaders["train_loader"], model)
106111
rel_paths_train = [item[0][len(root_dir) + 1:] for item in loaders["train_loader"].dataset.imgs]
107-
embeddings.extend(embeddings_train)
108-
labels.extend(labels_train)
109-
rel_paths.extend(rel_paths_train)
110-
111-
## Return embeddings without plotting
112-
if kwargs.get("ret_embeddings"):
113-
df = pd.DataFrame({"image_name": [os.path.basename(p) for p in rel_paths], "class": [os.path.basename(os.path.dirname(p)) for p in rel_paths]})
114-
return pd.concat([df, pd.DataFrame(embeddings)], axis=1)
115112

113+
# Build training DataFrame (meta + embeddings)
114+
df_train_meta = pd.DataFrame({
115+
"image_name": [os.path.basename(p) for p in rel_paths_train],
116+
"class_str": [os.path.basename(os.path.dirname(p)) for p in rel_paths_train],
117+
"dataset": "train",
118+
})
119+
df_train = pd.concat([df_train_meta, pd.DataFrame(embeddings_train)], axis=1)
120+
df_embeddings = pd.concat([df_embeddings, df_train], ignore_index=True)
121+
122+
## Stable order before reduction
123+
df_embeddings = df_embeddings.sort_values(by=["class_str", "dataset","image_name"]).reset_index(drop=True)
124+
116125
## Reduce dimensionality
117126
if not perplexity:
118-
perplexity = min(30, max(5, (len(embeddings) - 1) / 3))
119-
print(f"tSNE: using a perplexity value of {perplexity}")
120-
reduced_data, colnames, _ = helpers.embbedings_dimension_reductions(embeddings, perplexity)
127+
perplexity = min(30, max(5, (len(df_embeddings) - 1) / 3))
128+
print(f"tSNE: using perplexity {perplexity}")
129+
# Reduce on numeric embedding columns only
130+
embedding_matrix = df_embeddings.select_dtypes(include=[np.number])
131+
reduced_data, colnames, _ = helpers.embbedings_dimension_reductions(embedding_matrix, perplexity, seed)
121132

122133
## make plot
123-
df = pd.DataFrame(reduced_data, columns=colnames)
124-
df["paths"] = [os.path.join("..", "..", p) for p in rel_paths]
125-
df["class"], df["class_str"] = labels, [os.path.basename(os.path.dirname(p)) for p in rel_paths]
126-
df["dataset"] = df["paths"].apply(lambda x: "validation" if "/val/" in x else "train")
127-
helpers.bokeh_plot(df, out_path=plot_path, **plot_config)
134+
df_plot = df_embeddings.select_dtypes(exclude=[np.number])
135+
df_plot['paths'] = df_plot.apply(lambda row: os.path.join(
136+
"..", "..", "data", run_name, row['dataset'], row['class_str'], row['image_name']), axis=1)
137+
df_plot["class"] = pd.Categorical(df_plot["class_str"]).codes
138+
df_plot = pd.concat([df_plot, pd.DataFrame(reduced_data, columns=colnames)], axis=1)
139+
140+
helpers.bokeh_plot(df_plot, out_path=plot_path, **plot_config)
141+
142+
# Return both DataFrames (embeddings and plot table) only when return_results is set
143+
if return_results:
144+
return df_embeddings, df_plot
145+
146+
128147

129148

130149
def cli():

bioencoder/vis/helpers.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def feature_map_normalization(f):
220220
act_map /= act_map.max()
221221
return act_map
222222

223-
def embbedings_dimension_reductions(data_table, perplexity):
223+
def embbedings_dimension_reductions(data_table, perplexity, seed):
224224
"""
225225
Perform dimension reduction on the input data.
226226
@@ -235,14 +235,25 @@ def embbedings_dimension_reductions(data_table, perplexity):
235235
mean = np.mean(data_table, axis=0)
236236
std = np.std(data_table, axis=0)
237237
norm_data = (data_table - mean) / std
238+
239+
## PCA
238240
pca_obj = decomposition.PCA(n_components=2)
239241
pca = pca_obj.fit_transform(norm_data)
240-
tsne = manifold.TSNE(perplexity=perplexity, learning_rate='auto', init='pca').fit_transform(norm_data)
242+
243+
## tSNE
244+
tsne = manifold.TSNE(
245+
perplexity=perplexity,
246+
random_state=seed,
247+
learning_rate='auto',
248+
method="exact",
249+
init=pca
250+
).fit_transform(norm_data)
251+
241252
names = ['PC1', 'PC2', 'tSNE-0', 'tSNE-1']
242253
return np.hstack((pca, tsne)), names, pca_obj
243254

244255

245-
def bokeh_plot(df, out_path='plot.html', color_map="jet1", color_classes=None, plot_style=1,
256+
def bokeh_plot(df, out_path='plot.html', color_map="jet", color_classes=None, plot_style=1,
246257
point_size=10, **kwargs):
247258
"""
248259
Plot a scatter plot of the PCA and t-SNE dimensions of the data using bokeh.
@@ -264,7 +275,7 @@ class labels of the images).
264275
raise ValueError("The dataframe must have columns 'paths' and 'class'")
265276

266277
unique_classes = df['class'].unique()
267-
unique_datasets = df['dataset'].unique()
278+
unique_datasets = df['dataset'].astype(str).unique()
268279
markers = ['circle', 'square'] # Define markers for each group
269280

270281
## Color management
@@ -273,10 +284,9 @@ class labels of the images).
273284
f"Number of classes is {len(unique_classes)}, but only {len(color_classes)} colors provided."
274285
)
275286

276-
# Convert dict to DataFrame and merge colors
277-
df_col = pd.DataFrame.from_dict(color_classes.items())
278-
df_col.columns = ["class_str", "color"]
279-
df = df.merge(df_col, how="left", left_on="class", right_on="class_str").drop(columns=["class_str"])
287+
# Convert dict to DataFrame and merge colors by class_str (deterministic, no row reordering)
288+
df_col = pd.DataFrame(list(color_classes.items()), columns=["class_str", "color"])
289+
df = df.merge(df_col, how="left", on="class_str")
280290

281291
else:
282292
num_classes = len(unique_classes)
@@ -285,7 +295,6 @@ class labels of the images).
285295
colors_str = ['#%02x%02x%02x' % tuple(c[:3]) for c in colors_raw]
286296
df['color'] = colors_str
287297

288-
289298
source = ColumnDataSource(df)
290299
bplot.output_file(out_path)
291300

@@ -335,19 +344,22 @@ class labels of the images).
335344
pca = bplot.figure(tools=tools0, title="PCA", match_aspect=True)
336345
tsne = bplot.figure(tools=tools1, title="t-SNE", match_aspect=True)
337346

338-
# Store renderers for dataset legend
339-
legend_items_dataset = []
340-
341-
# Scatter plots with different markers for datasets
342-
for dataset, marker in zip(unique_datasets, markers):
343-
dataset_source = ColumnDataSource(df[df['dataset'].astype(str) == dataset]) # Filter dataset-specific data
344-
r = pca.scatter('PC1', 'PC2', source=dataset_source, color='color', size=point_size, marker=marker)
345-
tsne.scatter('tSNE-0', 'tSNE-1', source=dataset_source, color='color', size=point_size, marker=marker)
346-
legend_items_dataset.append(LegendItem(label=str(dataset), renderers=[r]))
347-
348-
# Create and add horizontal legend for dataset markers
349-
legend_dataset = Legend(items=legend_items_dataset, orientation="horizontal")
350-
pca.add_layout(legend_dataset, 'below')
347+
# Single source scatter with per-point markers mapped from dataset; no reordering
348+
from itertools import cycle, islice
349+
dataset_factors = list(pd.unique(df['dataset'].astype(str)))
350+
marker_factors = list(islice(cycle(markers), len(dataset_factors)))
351+
marker_map = factor_mark('dataset', marker_factors, dataset_factors)
352+
353+
pca.scatter('PC1', 'PC2', source=source, color='color', size=point_size, marker=marker_map, legend_field='dataset')
354+
tsne.scatter('tSNE-0', 'tSNE-1', source=source, color='color', size=point_size, marker=marker_map)
355+
356+
# Single legend below the PCA plot
357+
if getattr(pca, 'legend', None) and len(pca.legend) > 0:
358+
pca.legend[0].orientation = "horizontal"
359+
pca.legend[0].location = "left"
360+
pca.add_layout(pca.legend[0], 'below')
361+
if getattr(tsne, 'legend', None):
362+
tsne.legend.visible = False
351363

352364
# Display plots
353365
p = bplot.gridplot([[pca, tsne]])

bioencoder_configs/plot_stage1.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,7 @@ color_map: 'Set1' # Default color map; see https://matplotlib.org/stable/users/e
1818
#color_classes: # overrides color_map
1919
#class1: "#FFD467"
2020
#class2: "#4DC9F2"
21-
21+
22+
return_results: False
23+
2224

0 commit comments

Comments
 (0)