@@ -43,13 +43,28 @@ def __init__(
4343 `site_class`, `cohort_size`, `min_cohort_size`, `max_cohort_size`,
4444 `random_seed`.
4545
46+ .. versionchanged:: 9.0.0
47+ The `cohorts` parameter has been added to enable cohort-based
48+ downsampling via the `max_cohort_size` parameter.
4649 """ ,
4750 returns = ("df_pca" , "evr" ),
4851 notes = """
4952 This computation may take some time to run, depending on your computing
5053 environment. Results of this computation will be cached and re-used if
5154 the `results_cache` parameter was set when instantiating the API client.
5255 """ ,
56+ examples = """
57+ Run a PCA, downsampling to a maximum of 20 samples per country::
58+
59+ >>> import malariagen_data
60+ >>> ag3 = malariagen_data.Ag3()
61+ >>> df_pca, evr = ag3.pca(
62+ ... region="3R",
63+ ... n_snps=1000,
64+ ... cohorts="country",
65+ ... max_cohort_size=20,
66+ ... )
67+ """ ,
5368 )
5469 def pca (
5570 self ,
@@ -61,6 +76,10 @@ def pca(
6176 sample_query : Optional [base_params .sample_query ] = None ,
6277 sample_query_options : Optional [base_params .sample_query_options ] = None ,
6378 sample_indices : Optional [base_params .sample_indices ] = None ,
79+ cohorts : Optional [base_params .cohorts ] = None ,
80+ cohort_size : Optional [base_params .cohort_size ] = None ,
81+ min_cohort_size : Optional [base_params .min_cohort_size ] = None ,
82+ max_cohort_size : Optional [base_params .max_cohort_size ] = None ,
6483 site_mask : Optional [base_params .site_mask ] = base_params .DEFAULT ,
6584 site_class : Optional [base_params .site_class ] = None ,
6685 min_minor_ac : Optional [
@@ -69,9 +88,6 @@ def pca(
6988 max_missing_an : Optional [
7089 base_params .max_missing_an
7190 ] = pca_params .max_missing_an_default ,
72- cohort_size : Optional [base_params .cohort_size ] = None ,
73- min_cohort_size : Optional [base_params .min_cohort_size ] = None ,
74- max_cohort_size : Optional [base_params .max_cohort_size ] = None ,
7591 exclude_samples : Optional [base_params .samples ] = None ,
7692 fit_exclude_samples : Optional [base_params .samples ] = None ,
7793 random_seed : base_params .random_seed = 42 ,
@@ -82,6 +98,41 @@ def pca(
8298 # invalidate any previously cached data.
8399 name = "pca_v4"
84100
101+ # Handle cohort downsampling.
102+ if cohorts is not None :
103+ if max_cohort_size is None :
104+ raise ValueError ("`max_cohort_size` is required when `cohorts` is provided." )
105+ if sample_indices is not None :
106+ raise ValueError (
107+ "Cannot use `sample_indices` with `cohorts` and `max_cohort_size`."
108+ )
109+ if cohort_size is not None or min_cohort_size is not None :
110+ raise ValueError (
111+ "Cannot use `cohort_size` or `min_cohort_size` with `cohorts`."
112+ )
113+ df_samples = self .sample_metadata (
114+ sample_sets = sample_sets ,
115+ sample_query = sample_query ,
116+ sample_query_options = sample_query_options ,
117+ )
118+ # N.B., we are going to overwrite the sample_indices parameter here.
119+ groups = df_samples .groupby (cohorts , sort = False )
120+ ix = []
121+ for _ , group in groups :
122+ if len (group ) > max_cohort_size :
123+ ix .extend (
124+ group .sample (
125+ n = max_cohort_size , random_state = random_seed , replace = False
126+ ).index
127+ )
128+ else :
129+ ix .extend (group .index )
130+ sample_indices = ix
131+ # From this point onwards, the sample_query is no longer needed, because
132+ # the sample selection is defined by the sample_indices.
133+ sample_query = None
134+ sample_query_options = None
135+
85136 # Normalize params for consistent hash value.
86137 (
87138 sample_sets_prepped ,
@@ -105,6 +156,7 @@ def pca(
105156 min_minor_ac = min_minor_ac ,
106157 max_missing_an = max_missing_an ,
107158 n_components = n_components ,
159+ cohorts = cohorts ,
108160 cohort_size = cohort_size ,
109161 min_cohort_size = min_cohort_size ,
110162 max_cohort_size = max_cohort_size ,
@@ -122,10 +174,10 @@ def pca(
122174 self .results_cache_set (name = name , params = params , results = results )
123175
124176 # Unpack results.
125- coords = results ["coords" ]
126- evr = results ["evr" ]
127- samples = results ["samples" ]
128- loc_keep_fit = results ["loc_keep_fit" ]
177+ coords = np . array ( results ["coords" ])
178+ evr = np . array ( results ["evr" ])
179+ samples = np . array ( results ["samples" ])
180+ loc_keep_fit = np . array ( results ["loc_keep_fit" ])
129181
130182 # Load sample metadata.
131183 df_samples = self .sample_metadata (
@@ -166,6 +218,7 @@ def _pca(
166218 random_seed ,
167219 chunks ,
168220 inline_array ,
221+ ** kwargs ,
169222 ):
170223 # Load diplotypes.
171224 gn , samples = self .biallelic_diplotypes (
0 commit comments