@@ -1612,20 +1612,22 @@ def _karyotype_tags_n_alt(gt, alts, inversion_alts):
16121612 return inv_n_alt
16131613
16141614
1615- def prep_samples_for_cohort_grouping (* , df_samples , area_by , period_by ):
1615+ def prep_samples_for_cohort_grouping (
1616+ * , df_samples , area_by , period_by , taxon_by = "taxon"
1617+ ):
16161618 # Take a copy, as we will modify the dataframe.
16171619 df_samples = df_samples .copy ()
16181620
16191621 # Fix "intermediate" or "unassigned" taxon values - we only want to build
16201622 # cohorts with clean taxon calls, so we set other values to None.
16211623 loc_intermediate_taxon = (
1622- df_samples ["taxon" ].str .startswith ("intermediate" ).fillna (False )
1624+ df_samples [taxon_by ].str .startswith ("intermediate" ).fillna (False )
16231625 )
1624- df_samples .loc [loc_intermediate_taxon , "taxon" ] = None
1626+ df_samples .loc [loc_intermediate_taxon , taxon_by ] = None
16251627 loc_unassigned_taxon = (
1626- df_samples ["taxon" ].str .startswith ("unassigned" ).fillna (False )
1628+ df_samples [taxon_by ].str .startswith ("unassigned" ).fillna (False )
16271629 )
1628- df_samples .loc [loc_unassigned_taxon , "taxon" ] = None
1630+ df_samples .loc [loc_unassigned_taxon , taxon_by ] = None
16291631
16301632 # Add period column.
16311633 if period_by == "year" :
@@ -1647,7 +1649,9 @@ def prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by):
16471649 return df_samples
16481650
16491651
1650- def build_cohorts_from_sample_grouping (* , group_samples_by_cohort , min_cohort_size ):
1652+ def build_cohorts_from_sample_grouping (
1653+ * , group_samples_by_cohort , min_cohort_size , taxon_by = "taxon"
1654+ ):
16511655 # Build cohorts dataframe.
16521656 df_cohorts = group_samples_by_cohort .agg (
16531657 size = ("sample_id" , len ),
@@ -1669,7 +1673,7 @@ def build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_si
16691673 # Create a label that is similar to the cohort metadata,
16701674 # although this won't be perfect.
16711675 df_cohorts ["label" ] = df_cohorts .apply (
1672- lambda v : f"{ v .area } _{ v . taxon [:4 ]} _{ v .period } " , axis = "columns"
1676+ lambda v : f"{ v .area } _{ v [ taxon_by ] [:4 ]} _{ v .period } " , axis = "columns"
16731677 )
16741678
16751679 # Apply minimum cohort size.
0 commit comments