11import numpy as np
22import pandas as pd
3+ import re
34import xarray as xr
45import plotly .express as px
56from textwrap import dedent
1415from .base import AnophelesBase
1516
1617
def prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by, taxon_by):
    """Prepare a samples dataframe for grouping into cohorts.

    Works on a copy of ``df_samples``; the caller's dataframe is not modified.

    Parameters
    ----------
    df_samples : pandas.DataFrame
        Sample metadata. Must contain the columns named by ``taxon_by`` and
        ``area_by``, and the column named by ``period_by`` when that is not
        one of the supported period names.
    area_by : str
        Name of the column copied into the new "area" column.
    period_by : str
        Either one of 'year', 'quarter' or 'month', in which case the
        "period" column is computed row by row with the matching
        ``_make_sample_period_*`` helper, or the name of an existing column
        whose values are all ``pandas.Period`` or null, in which case that
        column is copied into "period".
    taxon_by : str
        Name of the column holding taxon calls. Values starting with
        "intermediate" or "unassigned" are replaced with None.

    Returns
    -------
    pandas.DataFrame
        The copied dataframe with the taxon column cleaned and with "period"
        and "area" columns added.

    Raises
    ------
    ValueError
        If ``period_by`` is neither a supported period name nor the name of
        a column in ``df_samples``.
    TypeError
        If ``period_by`` names a column containing a value that is neither
        null nor a ``pandas.Period``.
    """
    # Take a copy, as we will modify the dataframe.
    df_samples = df_samples.copy()

    # Fix "intermediate" or "unassigned" taxon values - we only want to build
    # cohorts with clean taxon calls, so we set other values to None.
    # Missing taxon values yield NaN from startswith(); fillna(False) keeps
    # them out of the mask.
    for unclean_prefix in ("intermediate", "unassigned"):
        loc_unclean_taxon = (
            df_samples[taxon_by].str.startswith(unclean_prefix).fillna(False)
        )
        df_samples.loc[loc_unclean_taxon, taxon_by] = None

    # Map supported period_by values to functions that return either the
    # relevant pd.Period or pd.NaT per row.
    period_by_funcs = {
        "year": _make_sample_period_year,
        "quarter": _make_sample_period_quarter,
        "month": _make_sample_period_month,
    }

    # Get the matching function for the specified period_by value, or None.
    period_by_func = period_by_funcs.get(period_by)

    if period_by_func is not None:
        # Apply the matching period_by function to create a new "period" column.
        df_samples["period"] = df_samples.apply(period_by_func, axis="columns")
    else:
        # Not a supported period name, so period_by must name an existing column.
        if period_by not in df_samples.columns:
            raise ValueError(
                f"Invalid value for `period_by`: {period_by!r}. Either specify the name of an existing column "
                "or a supported period: 'year', 'quarter', or 'month'."
            )

        # Raise a TypeError if the specified period_by column contains anything
        # other than pd.Period instances or nulls.
        is_period_or_null = df_samples[period_by].apply(
            lambda value: pd.isnull(value) or isinstance(value, pd.Period)
        )
        if not is_period_or_null.all():
            raise TypeError(
                f"Invalid values in {period_by!r} column. Must be either pandas.Period or null."
            )

        # Copy the specified period_by column to a new "period" column.
        df_samples["period"] = df_samples[period_by]

    # Copy the specified area_by column to a new "area" column.
    df_samples["area"] = df_samples[area_by]

    return df_samples
5074
5175
52- def build_cohorts_from_sample_grouping (* , group_samples_by_cohort , min_cohort_size ):
76+ def build_cohorts_from_sample_grouping (
77+ * , group_samples_by_cohort , min_cohort_size , taxon_by
78+ ):
5379 # Build cohorts dataframe.
5480 df_cohorts = group_samples_by_cohort .agg (
5581 size = ("sample_id" , len ),
@@ -70,9 +96,16 @@ def build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_si
7096 df_cohorts ["period_end" ] = cohort_period_end
7197 # Create a label that is similar to the cohort metadata,
7298 # although this won't be perfect.
73- df_cohorts ["label" ] = df_cohorts .apply (
74- lambda v : f"{ v .area } _{ v .taxon [:4 ]} _{ v .period } " , axis = "columns"
75- )
99+ if taxon_by == frq_params .taxon_by_default :
100+ df_cohorts ["label" ] = df_cohorts .apply (
101+ lambda v : f"{ v .area } _{ v [taxon_by ][:4 ]} _{ v .period } " , axis = "columns"
102+ )
103+ else :
104+ # Replace non-alphanumeric characters in the taxon with underscores.
105+ df_cohorts ["label" ] = df_cohorts .apply (
106+ lambda v : f"{ v .area } _{ re .sub (r'[^A-Za-z0-9]+' , '_' , str (v [taxon_by ]))} _{ v .period } " ,
107+ axis = "columns" ,
108+ )
76109
77110 # Apply minimum cohort size.
78111 df_cohorts = df_cohorts .query (f"size >= { min_cohort_size } " ).reset_index (drop = True )
0 commit comments