Skip to content

Commit 4ac3c2a

Browse files
committed
fix class order in explorer / vis deprec. warnings
1 parent 92f1db1 commit 4ac3c2a

3 files changed

Lines changed: 16 additions & 14 deletions

File tree

bioencoder/scripts/interactive_plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def interactive_plots(
7777
plot_dir = os.path.join(root_dir, "plots", run_name)
7878
os.makedirs(plot_dir, exist_ok=True)
7979
plot_path = os.path.join(plot_dir, "embeddings_interactive_plot.html")
80-
if not overwrite and not (return_embeddings or return_coords):
80+
if not overwrite and not return_results:
8181
assert not os.path.isfile(plot_path), f"File already exists: {plot_path}"
8282

8383
## Load model and set up

bioencoder/scripts/model_explorer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
import streamlit as st
1111
from streamlit_option_menu import option_menu
12+
from torchvision.datasets import ImageFolder
1213

1314
from bioencoder import config, utils, vis
1415

@@ -63,9 +64,7 @@ def model_explorer(
6364
## load bioencoer config
6465
root_dir = config.root_dir
6566
run_name = config.run_name
66-
67-
class_names = os.listdir(os.path.join(root_dir, "data", run_name, "train"))
68-
67+
6968
## load config
7069
hyperparams = utils.load_yaml(config_path)
7170

@@ -79,7 +78,6 @@ def model_explorer(
7978

8079
## get swa path
8180
ckpt_pretrained = os.path.join(root_dir, "weights", run_name, stage, "swa")
82-
8381
if stage == 'first':
8482
vis_funcs = ['Filters', 'Activations', 'Saliency']
8583
else:
@@ -90,28 +88,30 @@ def model_explorer(
9088

9189
# Sidebar
9290
img_path = "https://github.com/agporto/BioEncoder/raw/main/assets/bioencoder_logo.png"
93-
st.sidebar.image(img_path, use_column_width=True)
91+
st.sidebar.image(img_path, width='stretch')
9492
st.sidebar.title("BioEncoder Model Explorer")
9593

9694
# Image upload
9795
uploaded_file = st.sidebar.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
9896

99-
## get image transformations
100-
transform = utils.get_transforms(hyperparams, no_aug=True)
101-
10297
# Load the model and add to cache
10398
model = load_model(
10499
ckpt_pretrained,
105100
backbone,
106101
num_classes,
107102
stage
108103
)
109-
104+
105+
## get class names
106+
train_folder = os.path.join(root_dir, "data", run_name, "train")
107+
train_folder_timm = ImageFolder(train_folder)
108+
class_names = train_folder_timm.classes
109+
110110
if uploaded_file is not None:
111111

112112
# Display the uploaded image
113113
image = Image.open(uploaded_file).convert('RGB')
114-
st.sidebar.image(image, caption="Input Image", use_column_width=True)
114+
st.sidebar.image(image, caption="Input Image", width='stretch')
115115

116116
# resize image
117117
image_resized = image.resize((img_size, img_size))

bioencoder/scripts/train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,11 @@ def train(
8484
"valid_batch_size": hyperparams["dataloaders"]["valid_batch_size"],
8585
}
8686
num_workers = hyperparams["dataloaders"]["num_workers"]
87-
aug_sample = hyperparams["augmentations"].get("sample_save", False)
88-
aug_sample_n = hyperparams["augmentations"].get("sample_n", 5)
89-
aug_sample_seed = hyperparams["augmentations"].get("sample_seed", 42)
87+
aug_config = hyperparams.get("augmentations", {})
88+
aug_sample = aug_config.get("sample_save", False)
89+
aug_sample_n = aug_config.get("sample_n", 5)
90+
aug_sample_seed = aug_config.get("sample_seed", 42)
91+
9092

9193
## manage directories and paths
9294
data_dir = os.path.join(root_dir, "data", run_name)

0 commit comments

Comments
 (0)