Skip to content

Commit 658c538

Browse files
Improve color parameter handling in plot_haplotype_network function
1 parent a85cf3d commit 658c538

1 file changed

Lines changed: 94 additions & 46 deletions

File tree

malariagen_data/anopheles.py

Lines changed: 94 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -2232,55 +2232,103 @@ def plot_haplotype_network(
22322232
edges = np.triu(edges)
22332233
alt_edges = np.triu(alt_edges)
22342234

2235-
debug("setup colors")
2236-
color_values = None
2237-
color_values_display = None
2238-
color_discrete_map_display = None
2239-
ht_color_counts = None
2240-
if color is not None:
2241-
# Check if the color column exists in the dataframe
2242-
if color not in df_haps.columns:
2243-
raise ValueError(
2244-
f"Column '{color}' specified for coloring not found in sample data. Available columns: {', '.join(df_haps.columns)}"
2245-
)
2246-
# sanitise color column - necessary to avoid grey pie chart segments
2247-
df_haps["partition"] = (
2248-
df_haps[color].str.replace(r"\W", "", regex=True)
2249-
if color in df_haps.columns
2250-
else None
2251-
)
2235+
debug("setup colors")
2236+
color_values = None
2237+
color_values_display = None
2238+
color_discrete_map_display = None
2239+
ht_color_counts = None
2240+
2241+
if color is not None:
2242+
# Handle string case (direct column name or cohorts_ prefix)
2243+
if isinstance(color, str):
2244+
# Try direct column name
2245+
if color in df_haps.columns:
2246+
color_column = color
2247+
# Try with cohorts_ prefix
2248+
elif f"cohorts_{color}" in df_haps.columns:
2249+
color_column = f"cohorts_{color}"
2250+
# Neither exists, raise helpful error
2251+
else:
2252+
available_columns = ", ".join(df_haps.columns)
2253+
raise ValueError(
2254+
f"Column '{color}' or 'cohorts_{color}' not found in sample data. "
2255+
f"Available columns: {available_columns}"
2256+
)
2257+
2258+
# Now use the validated color_column for processing
2259+
df_haps["partition"] = df_haps[color_column].str.replace(
2260+
r"\W", "", regex=True
2261+
)
22522262

2253-
# extract all unique values of the color column
2254-
color_values = df_haps["partition"].fillna("<NA>").unique()
2255-
color_values_mapping = dict(zip(df_haps["partition"], df_haps[color]))
2256-
color_values_mapping["<NA>"] = "black"
2257-
color_values_display = [color_values_mapping[c] for c in color_values]
2263+
# extract all unique values of the color column
2264+
color_values = df_haps["partition"].fillna("<NA>").unique()
2265+
color_values_mapping = dict(
2266+
zip(df_haps["partition"], df_haps[color_column])
2267+
)
2268+
color_values_mapping["<NA>"] = "black"
2269+
color_values_display = [
2270+
color_values_mapping[c] for c in color_values
2271+
]
2272+
2273+
# Handle mapping/dictionary case
2274+
elif isinstance(color, Mapping):
2275+
# For mapping case, we need to create a new column based on the mapping
2276+
# Initialize with None
2277+
df_haps["partition"] = None
2278+
2279+
# Apply each query in the mapping to create the partition column
2280+
for label, query in color.items():
2281+
# Apply the query and assign the label to matching rows
2282+
mask = df_haps.eval(query)
2283+
df_haps.loc[mask, "partition"] = label
2284+
2285+
# Clean up the partition column to avoid issues with special characters
2286+
if df_haps["partition"].notna().any():
2287+
df_haps["partition"] = df_haps["partition"].str.replace(
2288+
r"\W", "", regex=True
2289+
)
2290+
2291+
# extract all unique values of the color column
2292+
color_values = df_haps["partition"].fillna("<NA>").unique()
2293+
# For mapping case, use partition values directly as they're already the labels
2294+
color_values_mapping = dict(
2295+
zip(df_haps["partition"], df_haps["partition"])
2296+
)
2297+
color_values_mapping["<NA>"] = "black"
2298+
color_values_display = [
2299+
color_values_mapping[c] for c in color_values
2300+
]
2301+
else:
2302+
# Invalid type
2303+
raise TypeError(
2304+
f"Expected color parameter to be a string or mapping, got {type(color).__name__}"
2305+
)
22582306

2259-
# count color values for each distinct haplotype
2260-
ht_color_counts = [
2261-
df_haps.iloc[list(s)]["partition"].value_counts().to_dict()
2262-
for s in ht_distinct_sets
2263-
]
2307+
# count color values for each distinct haplotype (same for both string and mapping cases)
2308+
ht_color_counts = [
2309+
df_haps.iloc[list(s)]["partition"].value_counts().to_dict()
2310+
for s in ht_distinct_sets
2311+
]
22642312

2265-
# Set up colors.
2266-
(
2267-
color_prepped,
2268-
color_discrete_map_prepped,
2269-
category_orders_prepped,
2270-
) = self._setup_sample_colors_plotly(
2271-
data=df_haps,
2272-
color="partition",
2273-
color_discrete_map=color_discrete_map,
2274-
color_discrete_sequence=color_discrete_sequence,
2275-
category_orders=category_orders,
2276-
)
2277-
del color_discrete_map
2278-
del color_discrete_sequence
2279-
del category_orders
2280-
color_discrete_map_display = {
2281-
color_values_mapping[v]: c
2282-
for v, c in color_discrete_map_prepped.items()
2283-
}
2313+
# Set up colors (same for both string and mapping cases)
2314+
(
2315+
color_prepped,
2316+
color_discrete_map_prepped,
2317+
category_orders_prepped,
2318+
) = self._setup_sample_colors_plotly(
2319+
data=df_haps,
2320+
color="partition",
2321+
color_discrete_map=color_discrete_map,
2322+
color_discrete_sequence=color_discrete_sequence,
2323+
category_orders=category_orders,
2324+
)
2325+
del color_discrete_map
2326+
del color_discrete_sequence
2327+
del category_orders
2328+
color_discrete_map_display = {
2329+
color_values_mapping[v]: c
2330+
for v, c in color_discrete_map_prepped.items()
2331+
}
22842332

22852333
debug("construct graph")
22862334
anon_width = np.sqrt(0.3 * node_size_factor)

0 commit comments

Comments (0)