@@ -2232,55 +2232,103 @@ def plot_haplotype_network(
22322232 edges = np .triu (edges )
22332233 alt_edges = np .triu (alt_edges )
22342234
2235- debug ("setup colors" )
2236- color_values = None
2237- color_values_display = None
2238- color_discrete_map_display = None
2239- ht_color_counts = None
2240- if color is not None :
2241- # Check if the color column exists in the dataframe
2242- if color not in df_haps .columns :
2243- raise ValueError (
2244- f"Column '{ color } ' specified for coloring not found in sample data. Available columns: { ', ' .join (df_haps .columns )} "
2245- )
2246- # sanitise color column - necessary to avoid grey pie chart segments
2247- df_haps ["partition" ] = (
2248- df_haps [color ].str .replace (r"\W" , "" , regex = True )
2249- if color in df_haps .columns
2250- else None
2251- )
2235+ debug ("setup colors" )
2236+ color_values = None
2237+ color_values_display = None
2238+ color_discrete_map_display = None
2239+ ht_color_counts = None
2240+
2241+ if color is not None :
2242+ # Handle string case (direct column name or cohorts_ prefix)
2243+ if isinstance (color , str ):
2244+ # Try direct column name
2245+ if color in df_haps .columns :
2246+ color_column = color
2247+ # Try with cohorts_ prefix
2248+ elif f"cohorts_{ color } " in df_haps .columns :
2249+ color_column = f"cohorts_{ color } "
2250+ # Neither exists, raise helpful error
2251+ else :
2252+ available_columns = ", " .join (df_haps .columns )
2253+ raise ValueError (
2254+ f"Column '{ color } ' or 'cohorts_{ color } ' not found in sample data. "
2255+ f"Available columns: { available_columns } "
2256+ )
2257+
2258+ # Now use the validated color_column for processing
2259+ df_haps ["partition" ] = df_haps [color_column ].str .replace (
2260+ r"\W" , "" , regex = True
2261+ )
22522262
2253- # extract all unique values of the color column
2254- color_values = df_haps ["partition" ].fillna ("<NA>" ).unique ()
2255- color_values_mapping = dict (zip (df_haps ["partition" ], df_haps [color ]))
2256- color_values_mapping ["<NA>" ] = "black"
2257- color_values_display = [color_values_mapping [c ] for c in color_values ]
2263+ # extract all unique values of the color column
2264+ color_values = df_haps ["partition" ].fillna ("<NA>" ).unique ()
2265+ color_values_mapping = dict (
2266+ zip (df_haps ["partition" ], df_haps [color_column ])
2267+ )
2268+ color_values_mapping ["<NA>" ] = "black"
2269+ color_values_display = [
2270+ color_values_mapping [c ] for c in color_values
2271+ ]
2272+
2273+ # Handle mapping/dictionary case
2274+ elif isinstance (color , Mapping ):
2275+ # For mapping case, we need to create a new column based on the mapping
2276+ # Initialize with None
2277+ df_haps ["partition" ] = None
2278+
2279+ # Apply each query in the mapping to create the partition column
2280+ for label , query in color .items ():
2281+ # Apply the query and assign the label to matching rows
2282+ mask = df_haps .eval (query )
2283+ df_haps .loc [mask , "partition" ] = label
2284+
2285+ # Clean up the partition column to avoid issues with special characters
2286+ if df_haps ["partition" ].notna ().any ():
2287+ df_haps ["partition" ] = df_haps ["partition" ].str .replace (
2288+ r"\W" , "" , regex = True
2289+ )
2290+
2291+ # extract all unique values of the color column
2292+ color_values = df_haps ["partition" ].fillna ("<NA>" ).unique ()
2293+ # For mapping case, use partition values directly as they're already the labels
2294+ color_values_mapping = dict (
2295+ zip (df_haps ["partition" ], df_haps ["partition" ])
2296+ )
2297+ color_values_mapping ["<NA>" ] = "black"
2298+ color_values_display = [
2299+ color_values_mapping [c ] for c in color_values
2300+ ]
2301+ else :
2302+ # Invalid type
2303+ raise TypeError (
2304+ f"Expected color parameter to be a string or mapping, got { type (color ).__name__ } "
2305+ )
22582306
2259- # count color values for each distinct haplotype
2260- ht_color_counts = [
2261- df_haps .iloc [list (s )]["partition" ].value_counts ().to_dict ()
2262- for s in ht_distinct_sets
2263- ]
2307+ # count color values for each distinct haplotype (same for both string and mapping cases)
2308+ ht_color_counts = [
2309+ df_haps .iloc [list (s )]["partition" ].value_counts ().to_dict ()
2310+ for s in ht_distinct_sets
2311+ ]
22642312
2265- # Set up colors.
2266- (
2267- color_prepped ,
2268- color_discrete_map_prepped ,
2269- category_orders_prepped ,
2270- ) = self ._setup_sample_colors_plotly (
2271- data = df_haps ,
2272- color = "partition" ,
2273- color_discrete_map = color_discrete_map ,
2274- color_discrete_sequence = color_discrete_sequence ,
2275- category_orders = category_orders ,
2276- )
2277- del color_discrete_map
2278- del color_discrete_sequence
2279- del category_orders
2280- color_discrete_map_display = {
2281- color_values_mapping [v ]: c
2282- for v , c in color_discrete_map_prepped .items ()
2283- }
2313+ # Set up colors (same for both string and mapping cases)
2314+ (
2315+ color_prepped ,
2316+ color_discrete_map_prepped ,
2317+ category_orders_prepped ,
2318+ ) = self ._setup_sample_colors_plotly (
2319+ data = df_haps ,
2320+ color = "partition" ,
2321+ color_discrete_map = color_discrete_map ,
2322+ color_discrete_sequence = color_discrete_sequence ,
2323+ category_orders = category_orders ,
2324+ )
2325+ del color_discrete_map
2326+ del color_discrete_sequence
2327+ del category_orders
2328+ color_discrete_map_display = {
2329+ color_values_mapping [v ]: c
2330+ for v , c in color_discrete_map_prepped .items ()
2331+ }
22842332
22852333 debug ("construct graph" )
22862334 anon_width = np .sqrt (0.3 * node_size_factor )
0 commit comments