@@ -73,7 +73,10 @@ def model_explorer(
7373 backbone = hyperparams ["model" ]["backbone" ]
7474 num_classes = hyperparams ["model" ].get ("num_classes" , None )
7575 stage = hyperparams ["model" ]["stage" ]
76-
76+ img_size = hyperparams .get ("img_size" , None )
77+ if img_size is None :
78+ raise ValueError ("config must include 'img_size'" )
79+
7780 ## get swa path
7881 ckpt_pretrained = os .path .join (root_dir , "weights" , run_name , stage , "swa" )
7982
@@ -94,7 +97,7 @@ def model_explorer(
9497 uploaded_file = st .sidebar .file_uploader ("Upload an Image" , type = ["png" , "jpg" , "jpeg" ])
9598
9699 ## get image transformations
97- transform = utils .get_transforms (hyperparams , valid = False )
100+ transform = utils .get_transforms (hyperparams , no_aug = True )
98101
99102 # Load the model and add to cache
100103 model = load_model (
@@ -105,9 +108,13 @@ def model_explorer(
105108 )
106109
107110 if uploaded_file is not None :
111+
108112 # Display the uploaded image
109113 image = Image .open (uploaded_file ).convert ('RGB' )
110114 st .sidebar .image (image , caption = "Input Image" , use_column_width = True )
115+
116+ # resize image
117+ image_resized = image .resize ((img_size , img_size ))
111118
112119 # Generate visualizations
113120 selected = option_menu (None , vis_funcs , icons = ['list' for _ in range (len (vis_funcs ))], menu_icon = "cast" , orientation = "horizontal" )
@@ -123,27 +130,31 @@ def model_explorer(
123130 layer = st .selectbox ("Select a layer" , layers .keys ())
124131 module = layers [layer ]
125132 max_acts = st .slider ("Max activations" , 5 , 64 , 25 )
126- result = vis .visualize_activations (model , module , image , max_acts = max_acts )
133+ result = vis .visualize_activations (model , module , image_resized , max_acts = max_acts )
127134 st .pyplot (result )
128135
129136 elif selected == 'Saliency' :
130- result = vis .saliency_map (model , image )
137+ result = vis .saliency_map (model , image_resized )
131138 st .pyplot (result )
132139
133140 elif selected == 'GradCAM' :
134141 # add activation type (Relu, Silu, etc_)
135- layers = [name .split ('.' )[0 ] for name , module in model .encoder .named_modules () if isinstance (module , (torch .nn .SiLU , torch .nn .ReLU ))]
142+ layers = [name .split ('.' )[0 ] for name , module in model .encoder .named_modules () \
143+ if isinstance (module , (torch .nn .SiLU , torch .nn .ReLU ))]
136144 layer_set = sorted (set (layers ))
137145 layer = st .selectbox ("Select a layer" , list (layer_set ), index = len (list (layer_set ))- 1 )
138- result = vis .grad_cam (model , model .encoder ,image ,target_layer = [layer ], target_category = None )
146+ result = vis .grad_cam (model , model .encoder ,image_resized ,target_layer = [layer ], target_category = None )
139147 st .pyplot (result )
140148
141149 elif selected == 'ConstrativeCAM' :
142- layers = [name .split ('.' )[0 ] for name , module in model .encoder .named_modules () if isinstance (module , (torch .nn .SiLU , torch .nn .ReLU ))]
150+ layers = [name .split ('.' )[0 ] for name , module in model .encoder .named_modules () \
151+ if isinstance (module , (torch .nn .SiLU , torch .nn .ReLU ))]
143152 layer_set = sorted (set (layers ))
144153 layer = st .selectbox ("Select a layer" , list (layer_set ), index = len (list (layer_set ))- 1 )
145154 target = st .selectbox ("Select a target" , class_names )
146- result = vis .contrast_cam (model , model .encoder ,image ,target_layer = [layer ], target_category = class_names .index (target ))
155+ result = vis .contrast_cam (
156+ model , model .encoder , image_resized ,target_layer = [layer ],
157+ target_category = class_names .index (target ))
147158 st .pyplot (result )
148159
149160
0 commit comments