@@ -69,19 +69,21 @@ def interactive_plots(
6969 "plot_style" : hyperparams .get ("plot_style" , 1 ),
7070 "point_size" : hyperparams .get ("point_size" , 10 ),
7171 }
72-
72+
73+ return_results = hyperparams .get ("return_results" , False )
74+
7375 ## directories and file management
7476 data_dir = os .path .join (root_dir , "data" , run_name )
7577 plot_dir = os .path .join (root_dir , "plots" , run_name )
7678 os .makedirs (plot_dir , exist_ok = True )
7779 plot_path = os .path .join (plot_dir , "embeddings_interactive_plot.html" )
78- if not overwrite and not kwargs . get ( "ret_embeddings" ):
80+ if not overwrite and not ( return_embeddings or return_coords ):
7981 assert not os .path .isfile (plot_path ), f"File already exists: { plot_path } "
8082
8183 ## Load model and set up
8284 print (f"Checkpoint: using { checkpoint } of { stage } stage" )
8385 ckpt_pretrained = os .path .join (root_dir , "weights" , run_name , stage , checkpoint )
84- utils .set_seed ()
86+ seed = utils .set_seed ()
8587 model = utils .build_model (backbone , second_stage = (stage == "second" ), num_classes = num_classes , ckpt_pretrained = ckpt_pretrained ).cuda ()
8688 model .use_projection_head (False )
8789 model .eval ()
@@ -91,40 +93,57 @@ def interactive_plots(
9193 loaders = utils .build_loaders (
9294 data_dir , transforms , batch_sizes , num_workers ,
9395 second_stage = (stage == "second" ), drop_last = False , shuffle_train = False )
94- embeddings , labels , rel_paths = [], [], []
9596
96- ## val set - batch size cant be zero
97+ ## val set (always computed)
9798 embeddings_val , labels_val = utils .compute_embeddings (loaders ["valid_loader" ], model )
9899 rel_paths_val = [item [0 ][len (root_dir ) + 1 :] for item in loaders ["valid_loader" ].dataset .imgs ]
99- embeddings .extend (embeddings_val )
100- labels .extend (labels_val )
101- rel_paths .extend (rel_paths_val )
100+ # Build validation DataFrame (meta + embeddings)
101+ df_val_meta = pd .DataFrame ({
102+ "image_name" : [os .path .basename (p ) for p in rel_paths_val ],
103+ "class_str" : [os .path .basename (os .path .dirname (p )) for p in rel_paths_val ],
104+ "dataset" : "val" ,
105+ })
106+ df_embeddings = pd .concat ([df_val_meta , pd .DataFrame (embeddings_val )], axis = 1 )
102107
103108 ## train set - skipped if zero batch size
104109 if batch_sizes ["train_batch_size" ] is not None :
105110 embeddings_train , labels_train = utils .compute_embeddings (loaders ["train_loader" ], model )
106111 rel_paths_train = [item [0 ][len (root_dir ) + 1 :] for item in loaders ["train_loader" ].dataset .imgs ]
107- embeddings .extend (embeddings_train )
108- labels .extend (labels_train )
109- rel_paths .extend (rel_paths_train )
110-
111- ## Return embeddings without plotting
112- if kwargs .get ("ret_embeddings" ):
113- df = pd .DataFrame ({"image_name" : [os .path .basename (p ) for p in rel_paths ], "class" : [os .path .basename (os .path .dirname (p )) for p in rel_paths ]})
114- return pd .concat ([df , pd .DataFrame (embeddings )], axis = 1 )
115112
113+ # Build training DataFrame (meta + embeddings)
114+ df_train_meta = pd .DataFrame ({
115+ "image_name" : [os .path .basename (p ) for p in rel_paths_train ],
116+ "class_str" : [os .path .basename (os .path .dirname (p )) for p in rel_paths_train ],
117+ "dataset" : "train" ,
118+ })
119+ df_train = pd .concat ([df_train_meta , pd .DataFrame (embeddings_train )], axis = 1 )
120+ df_embeddings = pd .concat ([df_embeddings , df_train ], ignore_index = True )
121+
122+ ## Stable order before reduction
123+ df_embeddings = df_embeddings .sort_values (by = ["class_str" , "dataset" ,"image_name" ]).reset_index (drop = True )
124+
116125 ## Reduce dimensionality
117126 if not perplexity :
118- perplexity = min (30 , max (5 , (len (embeddings ) - 1 ) / 3 ))
119- print (f"tSNE: using a perplexity value of { perplexity } " )
120- reduced_data , colnames , _ = helpers .embbedings_dimension_reductions (embeddings , perplexity )
127+ perplexity = min (30 , max (5 , (len (df_embeddings ) - 1 ) / 3 ))
128+ print (f"tSNE: using perplexity { perplexity } " )
129+ # Reduce on numeric embedding columns only
130+ embedding_matrix = df_embeddings .select_dtypes (include = [np .number ])
131+ reduced_data , colnames , _ = helpers .embbedings_dimension_reductions (embedding_matrix , perplexity , seed )
121132
122133 ## make plot
123- df = pd .DataFrame (reduced_data , columns = colnames )
124- df ["paths" ] = [os .path .join (".." , ".." , p ) for p in rel_paths ]
125- df ["class" ], df ["class_str" ] = labels , [os .path .basename (os .path .dirname (p )) for p in rel_paths ]
126- df ["dataset" ] = df ["paths" ].apply (lambda x : "validation" if "/val/" in x else "train" )
127- helpers .bokeh_plot (df , out_path = plot_path , ** plot_config )
134+ df_plot = df_embeddings .select_dtypes (exclude = [np .number ])
135+ df_plot ['paths' ] = df_plot .apply (lambda row : os .path .join (
136+ ".." , ".." , "data" , run_name , row ['dataset' ], row ['class_str' ], row ['image_name' ]), axis = 1 )
137+ df_plot ["class" ] = pd .Categorical (df_plot ["class_str" ]).codes
138+ df_plot = pd .concat ([df_plot , pd .DataFrame (reduced_data , columns = colnames )], axis = 1 )
139+
140+ helpers .bokeh_plot (df_plot , out_path = plot_path , ** plot_config )
141+
142+ # Return logic: either one or both
143+ if return_results :
144+ return df_embeddings , df_plot
145+
146+
128147
129148
130149def cli ():
0 commit comments