@@ -61,7 +61,7 @@ def interactive_plots(
6161 "valid_batch_size" : hyperparams .get ("dataloaders" , {}).get ("valid_batch_size" ,1 ),
6262 }
6363 num_workers = hyperparams .get ("dataloaders" , {}).get ("num_workers" , 4 )
64- perplexity = hyperparams .get ("perplexity" , 30 )
64+ perplexity = hyperparams .get ("perplexity" )
6565
6666 plot_config = {
6767 "color_classes" : hyperparams .get ("color_classes" , None ),
@@ -70,43 +70,40 @@ def interactive_plots(
7070 "point_size" : hyperparams .get ("point_size" , 10 ),
7171 }
7272
73-
74- ## Set up directories
73+ ## directories and file management
7574 data_dir = os .path .join (root_dir , "data" , run_name )
76- plot_path = os .path .join (root_dir , "plots" , run_name , f"embeddings_{ run_name } .html" )
75+ plot_dir = os .path .join (root_dir , "plots" , run_name )
76+ os .makedirs (plot_dir , exist_ok = True )
77+ plot_path = os .path .join (plot_dir , "embeddings_interactive_plot.html" )
7778 if not overwrite and not kwargs .get ("ret_embeddings" ):
78- assert not os .path .isfile (plot_path ), f"File exists: { plot_path } "
79+ assert not os .path .isfile (plot_path ), f"File already exists: { plot_path } "
7980
8081 ## Load model and set up
8182 print (f"Checkpoint: using { checkpoint } of { stage } stage" )
8283 ckpt_pretrained = os .path .join (root_dir , "weights" , run_name , stage , checkpoint )
8384 utils .set_seed ()
84- transforms = utils .build_transforms (hyperparams )
85- loaders = utils .build_loaders (data_dir , transforms , batch_sizes , num_workers , second_stage = (stage == "second" ))
8685 model = utils .build_model (backbone , second_stage = (stage == "second" ), num_classes = num_classes , ckpt_pretrained = ckpt_pretrained ).cuda ()
8786 model .use_projection_head (False )
8887 model .eval ()
8988
90- ## Determine which embeddings to compute
89+ ## prep computation
90+ transforms = utils .build_transforms (hyperparams )
91+ loaders = utils .build_loaders (
92+ data_dir , transforms , batch_sizes , num_workers ,
93+ second_stage = (stage == "second" ), drop_last = False , shuffle_train = False )
9194 embeddings , labels , rel_paths = [], [], []
9295
93- ## val batch size cant be zero
96+ ## val set - batch size can't be zero
9497 embeddings_val , labels_val = utils .compute_embeddings (loaders ["valid_loader" ], model )
95- if len (embeddings_val ) < len (loaders ["valid_loader" ].dataset .imgs ):
96- missed_imgs = len (loaders ["valid_loader" ].dataset .imgs ) - len (embeddings_val )
97- print (f"Warning: missed { missed_imgs } images because batch size was not a multiple of validation dataset size." )
98- rel_paths_val = [item [0 ][len (root_dir ) + 1 :] for item in loaders ["valid_loader" ].dataset .imgs [:len (embeddings_val )]]
98+ rel_paths_val = [item [0 ][len (root_dir ) + 1 :] for item in loaders ["valid_loader" ].dataset .imgs ]
9999 embeddings .extend (embeddings_val )
100100 labels .extend (labels_val )
101101 rel_paths .extend (rel_paths_val )
102102
103- ## train set embeddings
103+ ## train set - skipped if train batch size is None
104104 if batch_sizes ["train_batch_size" ] is not None :
105105 embeddings_train , labels_train = utils .compute_embeddings (loaders ["train_loader" ], model )
106- if len (embeddings_train ) < len (loaders ["train_loader" ].dataset .imgs ):
107- missed_imgs = len (loaders ["train_loader" ].dataset .imgs ) - len (embeddings_train )
108- print (f"Warning: missed { missed_imgs } images because batch size was not a multiple of training dataset size." )
109- rel_paths_train = [item [0 ][len (root_dir ) + 1 :] for item in loaders ["train_loader" ].dataset .imgs [:len (embeddings_train )]]
106+ rel_paths_train = [item [0 ][len (root_dir ) + 1 :] for item in loaders ["train_loader" ].dataset .imgs ]
110107 embeddings .extend (embeddings_train )
111108 labels .extend (labels_train )
112109 rel_paths .extend (rel_paths_train )
@@ -118,7 +115,7 @@ def interactive_plots(
118115
119116 ## Reduce dimensionality
120117 if not perplexity :
121- perplexity = min (100 , len (embeddings ) // 2 )
118+ perplexity = min (30 , max ( 5 , ( len (embeddings ) - 1 ) / 3 ) )
122119 print (f"tSNE: using a perplexity value of { perplexity } " )
123120 reduced_data , colnames , _ = helpers .embbedings_dimension_reductions (embeddings , perplexity )
124121
@@ -127,7 +124,6 @@ def interactive_plots(
127124 df ["paths" ] = [os .path .join (".." , ".." , p ) for p in rel_paths ]
128125 df ["class" ], df ["class_str" ] = labels , [os .path .basename (os .path .dirname (p )) for p in rel_paths ]
129126 df ["dataset" ] = df ["paths" ].apply (lambda x : "validation" if "/val/" in x else "train" )
130-
131127 helpers .bokeh_plot (df , out_path = plot_path , ** plot_config )
132128
133129
0 commit comments