@@ -936,8 +936,8 @@ def plot_frequencies_time_series(
936936 legend_sizing : plotly_params .legend_sizing = "constant" ,
937937 show : plotly_params .show = True ,
938938 renderer : plotly_params .renderer = None ,
939- taxon : plotly_params . taxon = None ,
940- area : plotly_params . area = None ,
939+ taxa : frq_params . taxa = None ,
940+ areas : frq_params . areas = None ,
941941 ** kwargs ,
942942 ) -> plotly_params .figure :
943943 # Handle title.
@@ -949,13 +949,17 @@ def plot_frequencies_time_series(
949949 df_cohorts = ds [cohort_vars ].to_dataframe ()
950950 df_cohorts .columns = [c .split ("cohort_" )[1 ] for c in df_cohorts .columns ] # type: ignore
951951
952- # If specified, restrict the dataframe by taxon.
953- if taxon :
954- df_cohorts = df_cohorts [df_cohorts ["taxon" ] == taxon ]
955-
956- # If specified, restrict the dataframe by area.
957- if area :
958- df_cohorts = df_cohorts [df_cohorts ["area" ] == area ]
952+ # If specified, restrict the dataframe by taxa.
953+ if isinstance (taxa , str ):
954+ df_cohorts = df_cohorts [df_cohorts ["taxon" ] == taxa ]
955+ elif isinstance (taxa , (list , tuple )):
956+ df_cohorts = df_cohorts [df_cohorts ["taxon" ].isin (taxa )]
957+
958+ # If specified, restrict the dataframe by areas.
959+ if isinstance (areas , str ):
960+ df_cohorts = df_cohorts [df_cohorts ["area" ] == areas ]
961+ elif isinstance (areas , (list , tuple )):
962+ df_cohorts = df_cohorts [df_cohorts ["area" ].isin (areas )]
959963
960964 # Extract variant labels.
961965 variant_labels = ds ["variant_label" ].values
0 commit comments