99import torch
1010import streamlit as st
1111from streamlit_option_menu import option_menu
12+ from torchvision .datasets import ImageFolder
1213
1314from bioencoder import config , utils , vis
1415
@@ -63,9 +64,7 @@ def model_explorer(
6364 ## load bioencoer config
6465 root_dir = config .root_dir
6566 run_name = config .run_name
66-
67- class_names = os .listdir (os .path .join (root_dir , "data" , run_name , "train" ))
68-
67+
6968 ## load config
7069 hyperparams = utils .load_yaml (config_path )
7170
@@ -79,7 +78,6 @@ def model_explorer(
7978
8079 ## get swa path
8180 ckpt_pretrained = os .path .join (root_dir , "weights" , run_name , stage , "swa" )
82-
8381 if stage == 'first' :
8482 vis_funcs = ['Filters' , 'Activations' , 'Saliency' ]
8583 else :
@@ -90,28 +88,30 @@ def model_explorer(
9088
9189 # Sidebar
9290 img_path = "https://github.com/agporto/BioEncoder/raw/main/assets/bioencoder_logo.png"
93- st .sidebar .image (img_path , use_column_width = True )
91+ st .sidebar .image (img_path , width = 'stretch' )
9492 st .sidebar .title ("BioEncoder Model Explorer" )
9593
9694 # Image upload
9795 uploaded_file = st .sidebar .file_uploader ("Upload an Image" , type = ["png" , "jpg" , "jpeg" ])
9896
99- ## get image transformations
100- transform = utils .get_transforms (hyperparams , no_aug = True )
101-
10297 # Load the model and add to cache
10398 model = load_model (
10499 ckpt_pretrained ,
105100 backbone ,
106101 num_classes ,
107102 stage
108103 )
109-
104+
105+ ## get class names
106+ train_folder = os .path .join (root_dir , "data" , run_name , "train" )
107+ train_folder_timm = ImageFolder (train_folder )
108+ class_names = train_folder_timm .classes
109+
110110 if uploaded_file is not None :
111111
112112 # Display the uploaded image
113113 image = Image .open (uploaded_file ).convert ('RGB' )
114- st .sidebar .image (image , caption = "Input Image" , use_column_width = True )
114+ st .sidebar .image (image , caption = "Input Image" , width = 'stretch' )
115115
116116 # resize image
117117 image_resized = image .resize ((img_size , img_size ))
0 commit comments