@@ -46,94 +46,90 @@ def interactive_plots(
4646
4747 """
4848
49- ## load bioencoer config
50- root_dir = config .root_dir
51- run_name = config .run_name
52-
53- ## load config
49+ ## Load Bioencoder config
50+ root_dir , run_name = config .root_dir , config .run_name
5451 hyperparams = utils .load_yaml (config_path )
5552
56- ## parse config
53+ ## Parse config
5754 backbone = hyperparams ["model" ]["backbone" ]
5855 num_classes = hyperparams ["model" ].get ("num_classes" , None )
5956 checkpoint = hyperparams ["model" ].get ("checkpoint" , "swa" )
60- stage = hyperparams ["model" ].get ("stage" , "first" )
57+ stage = hyperparams .get ("model" , {}).get ("stage" , "first" )
58+
6159 batch_sizes = {
62- "train_batch_size" : hyperparams [ "dataloaders" ][ "train_batch_size" ] ,
63- "valid_batch_size" : hyperparams [ "dataloaders" ][ "valid_batch_size" ] ,
60+ "train_batch_size" : hyperparams . get ( "dataloaders" , {}). get ( "train_batch_size" ) ,
61+ "valid_batch_size" : hyperparams . get ( "dataloaders" , {}). get ( "valid_batch_size" , 1 ) ,
6462 }
65- num_workers = hyperparams ["dataloaders" ]["num_workers" ]
66- color_classes = hyperparams .get ("color_classes" , None )
67- color_map = hyperparams .get ("color_map" , "jet" )
68- plot_style = hyperparams .get ("plot_style" , 1 )
69- point_size = hyperparams .get ("point_size" , 10 )
70- perplexity = hyperparams .get ("perplexity" , None )
63+ num_workers = hyperparams .get ("dataloaders" , {}).get ("num_workers" , 4 )
64+ perplexity = hyperparams .get ("perplexity" , 30 )
7165
72- ## set up dirs
73- data_dir = os .path .join (root_dir ,"data" , run_name )
74- plot_dir = os .path .join (root_dir , "plots" , run_name )
75- os .makedirs (plot_dir , exist_ok = True )
66+ plot_config = {
67+ "color_classes" : hyperparams .get ("color_classes" , None ),
68+ "color_map" : hyperparams .get ("color_map" , "jet" ),
69+ "plot_style" : hyperparams .get ("plot_style" , 1 ),
70+ "point_size" : hyperparams .get ("point_size" , 10 ),
71+ }
72+
7673
77- ## plot path
78- plot_path = os .path .join (plot_dir , f"embeddings_{ run_name } .html" )
74+ ## Set up directories
75+ data_dir = os .path .join (root_dir , "data" , run_name )
76+ plot_path = os .path .join (root_dir , "plots" , run_name , f"embeddings_{ run_name } .html" )
7977 if not overwrite and not kwargs .get ("ret_embeddings" ):
8078 assert not os .path .isfile (plot_path ), f"File exists: { plot_path } "
8179
82- ## load weights
80+ ## Load model and set up
8381 print (f"Checkpoint: using { checkpoint } of { stage } stage" )
84- ckpt_pretrained = os .path .join (config .root_dir , "weights" , run_name , stage , checkpoint )
85-
86- ## set random seed
82+ ckpt_pretrained = os .path .join (root_dir , "weights" , run_name , stage , checkpoint )
8783 utils .set_seed ()
88-
89- ## extract embeddings
9084 transforms = utils .build_transforms (hyperparams )
91- loaders = utils .build_loaders (
92- data_dir , transforms , batch_sizes , num_workers , second_stage = (stage == "second" )
93- )
94- model = utils .build_model (
95- backbone ,
96- second_stage = (stage == "second" ),
97- num_classes = num_classes ,
98- ckpt_pretrained = ckpt_pretrained ,
99- ).cuda ()
85+ loaders = utils .build_loaders (data_dir , transforms , batch_sizes , num_workers , second_stage = (stage == "second" ))
86+ model = utils .build_model (backbone , second_stage = (stage == "second" ), num_classes = num_classes , ckpt_pretrained = ckpt_pretrained ).cuda ()
10087 model .use_projection_head (False )
10188 model .eval ()
102- embeddings_train , labels_train = utils .compute_embeddings (
103- loaders ["valid_loader" ], model
104- )
105-
106- ## load dataset
107- rel_paths_train = [item [0 ][len (root_dir ) + 1 :] for item in loaders ["valid_loader" ].dataset .imgs ]
108-
109- ## return embeddings without plotting
89+
90+ ## Determine which embeddings to compute
91+ embeddings , labels , rel_paths = [], [], []
92+
93+ ## val batch size cant be zero
94+ embeddings_val , labels_val = utils .compute_embeddings (loaders ["valid_loader" ], model )
95+ if len (embeddings_val ) < len (loaders ["valid_loader" ].dataset .imgs ):
96+ missed_imgs = len (loaders ["valid_loader" ].dataset .imgs ) - len (embeddings_val )
97+ print (f"Warning: missed { missed_imgs } images because batch size was not a multiple of validation dataset size." )
98+ rel_paths_val = [item [0 ][len (root_dir ) + 1 :] for item in loaders ["valid_loader" ].dataset .imgs [:len (embeddings_val )]]
99+ embeddings .extend (embeddings_val )
100+ labels .extend (labels_val )
101+ rel_paths .extend (rel_paths_val )
102+
103+ ## train set embeddings
104+ if batch_sizes ["train_batch_size" ] is not None :
105+ embeddings_train , labels_train = utils .compute_embeddings (loaders ["train_loader" ], model )
106+ if len (embeddings_train ) < len (loaders ["train_loader" ].dataset .imgs ):
107+ missed_imgs = len (loaders ["train_loader" ].dataset .imgs ) - len (embeddings_train )
108+ print (f"Warning: missed { missed_imgs } images because batch size was not a multiple of training dataset size." )
109+ rel_paths_train = [item [0 ][len (root_dir ) + 1 :] for item in loaders ["train_loader" ].dataset .imgs [:len (embeddings_train )]]
110+ embeddings .extend (embeddings_train )
111+ labels .extend (labels_train )
112+ rel_paths .extend (rel_paths_train )
113+
114+ ## Return embeddings without plotting
110115 if kwargs .get ("ret_embeddings" ):
116+ df = pd .DataFrame ({"image_name" : [os .path .basename (p ) for p in rel_paths ], "class" : [os .path .basename (os .path .dirname (p )) for p in rel_paths ]})
117+ return pd .concat ([df , pd .DataFrame (embeddings )], axis = 1 )
111118
112- df = pd .DataFrame ([os .path .basename (item ) for item in rel_paths_train ], columns = ["image_name" ])
113- df ["class" ] = [
114- os .path .basename (os .path .dirname (item [0 ])) for item in loaders ["valid_loader" ].dataset .imgs
115- ]
116- return pd .concat ([df , pd .DataFrame (embeddings_train )], axis = 1 )
117-
118- ## reduce dimensionality
119- perplexity = perplexity if perplexity else min (100 , len (embeddings_train ) // 2 )
120- reduced_data , colnames , _ = helpers .embbedings_dimension_reductions (
121- embeddings_train , perplexity
122- )
123- df = pd .DataFrame (reduced_data , columns = colnames )
124- df ["paths" ] = [ os .path .join (".." , ".." , item ) for item in rel_paths_train ]
125- df ["class" ] = labels_train
126- df ["class_str" ] = [
127- os .path .basename (os .path .dirname (item [0 ])) for item in loaders ["valid_loader" ].dataset .imgs
128- ]
129-
130- ## check if color matches n classes
131- if color_classes :
132- assert len (np .unique (labels_train )) == len (color_classes ), f"Number of classes is { len (np .unique (labels_train ))} , but you only provided { len (color_classes )} colors"
133-
134- helpers .bokeh_plot (df , out_path = plot_path , color_map = color_map , color_classes = color_classes ,
135- plot_style = plot_style , point_size = point_size )
119+ ## Reduce dimensionality
120+ if not perplexity :
121+ perplexity = min (100 , len (embeddings ) // 2 )
122+ print (f"tSNE: using a perplexity value of { perplexity } " )
123+ reduced_data , colnames , _ = helpers .embbedings_dimension_reductions (embeddings , perplexity )
136124
125+ ## make plot
126+ df = pd .DataFrame (reduced_data , columns = colnames )
127+ df ["paths" ] = [os .path .join (".." , ".." , p ) for p in rel_paths ]
128+ df ["class" ], df ["class_str" ] = labels , [os .path .basename (os .path .dirname (p )) for p in rel_paths ]
129+ df ["dataset" ] = df ["paths" ].apply (lambda x : "validation" if "/val/" in x else "train" )
130+
131+ helpers .bokeh_plot (df , out_path = plot_path , ** plot_config )
132+
137133
138134def cli ():
139135
0 commit comments