Skip to content

Commit 3f572ea

Browse files
committed
model explorer img size
1 parent 0d1e0ce commit 3f572ea

6 files changed

Lines changed: 25 additions & 16 deletions

File tree

bioencoder/core/augmentations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
 import albumentations as A
 from albumentations import pytorch as AT
 
-def get_transforms(config, valid=False):
+def get_transforms(config, no_aug=False):
     """
     Return a transformation pipeline based on the provided configuration.
 
@@ -22,7 +22,7 @@ def get_transforms(config, valid=False):
 
     return A.Compose([
        A.Resize(img_size, img_size, always_apply=True),
-       A.NoOp() if valid else aug,
+       A.NoOp() if no_aug else aug,
        A.Normalize(),
        AT.ToTensorV2()
     ])

bioencoder/core/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def build_transforms(config):
 
     """
     train_transforms = get_transforms(config)
-    valid_transforms = get_transforms(config, valid=True)
+    valid_transforms = get_transforms(config, no_aug=True)
 
     return {
         "train_transforms": train_transforms,

bioencoder/scripts/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def inference(
     utils.set_seed()
 
     ## get transformations
-    transform = utils.get_transforms(hyperparams, valid=False)
+    transform = utils.get_transforms(hyperparams, no_aug=True)
 
     ## build model
     if config.model_path != ckpt_pretrained:

bioencoder/scripts/model_explorer.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def model_explorer(
     backbone = hyperparams["model"]["backbone"]
     num_classes = hyperparams["model"].get("num_classes", None)
     stage = hyperparams["model"]["stage"]
-
+    img_size = hyperparams.get("img_size", None)
+    if img_size is None:
+        raise ValueError("config must include 'img_size'")
+
     ## get swa path
     ckpt_pretrained = os.path.join(root_dir, "weights", run_name, stage, "swa")
 
@@ -94,7 +97,7 @@ def model_explorer(
     uploaded_file = st.sidebar.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
 
     ## get image transformations
-    transform = utils.get_transforms(hyperparams, valid=False)
+    transform = utils.get_transforms(hyperparams, no_aug=True)
 
     # Load the model and add to cache
     model = load_model(
@@ -105,9 +108,13 @@ def model_explorer(
     )
 
     if uploaded_file is not None:
+
         # Display the uploaded image
         image = Image.open(uploaded_file).convert('RGB')
         st.sidebar.image(image, caption="Input Image", use_column_width=True)
+
+        # resize image
+        image_resized = image.resize((img_size, img_size))
 
         # Generate visualizations
         selected = option_menu(None, vis_funcs, icons=['list' for _ in range(len(vis_funcs))], menu_icon="cast", orientation="horizontal")
@@ -123,27 +130,31 @@ def model_explorer(
             layer = st.selectbox("Select a layer", layers.keys())
             module = layers[layer]
             max_acts = st.slider("Max activations", 5, 64, 25)
-            result = vis.visualize_activations(model, module, image, max_acts=max_acts)
+            result = vis.visualize_activations(model, module, image_resized, max_acts=max_acts)
             st.pyplot(result)
 
         elif selected == 'Saliency':
-            result = vis.saliency_map(model, image)
+            result = vis.saliency_map(model, image_resized)
             st.pyplot(result)
 
         elif selected == 'GradCAM':
             # add activation type (Relu, Silu, etc_)
-            layers =[name.split('.')[0] for name, module in model.encoder.named_modules() if isinstance(module, (torch.nn.SiLU, torch.nn.ReLU))]
+            layers =[name.split('.')[0] for name, module in model.encoder.named_modules() \
+                     if isinstance(module, (torch.nn.SiLU, torch.nn.ReLU))]
             layer_set = sorted(set(layers))
             layer = st.selectbox("Select a layer", list(layer_set), index=len(list(layer_set))-1)
-            result = vis.grad_cam(model, model.encoder,image,target_layer=[layer], target_category= None)
+            result = vis.grad_cam(model, model.encoder,image_resized,target_layer=[layer], target_category= None)
             st.pyplot(result)
 
         elif selected == 'ConstrativeCAM':
-            layers =[name.split('.')[0] for name, module in model.encoder.named_modules() if isinstance(module, (torch.nn.SiLU, torch.nn.ReLU))]
+            layers =[name.split('.')[0] for name, module in model.encoder.named_modules() \
+                     if isinstance(module, (torch.nn.SiLU, torch.nn.ReLU))]
             layer_set = sorted(set(layers))
             layer = st.selectbox("Select a layer", list(layer_set), index=len(list(layer_set))-1)
             target = st.selectbox("Select a target", class_names)
-            result = vis.contrast_cam(model, model.encoder,image,target_layer=[layer], target_category=class_names.index(target))
+            result = vis.contrast_cam(
+                model, model.encoder, image_resized,target_layer=[layer],
+                target_category=class_names.index(target))
             st.pyplot(result)
 

bioencoder/scripts/model_explorer_wrapper.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@ def model_explorer_wrapper(config_path):
     process = ["streamlit", "run", script_path , "--", "--config-path", config_path]
     subprocess.run(process)
 
-
-
-
 def cli():
 
     parser = argparse.ArgumentParser()

bioencoder/vis/methods.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def saliency_map(model, img, device = 'cuda', save_path = None):
         plt.savefig(save_path)
     return fig
 
-def grad_cam(model, module, img, target_layer = ["4"], target_category= None, device = 'cuda', save_path = None):
+def grad_cam(model, module, img, target_layer = ["4"],
+             target_category= None, device = 'cuda', save_path = None):
 
     for param in model.parameters():
         param.requires_grad = True

0 commit comments

Comments
 (0)