@@ -42,7 +42,6 @@ def __init__(
4242 The following additional parameters were also added in version 8.0.0:
4343 `site_class`, `cohort_size`, `min_cohort_size`, `max_cohort_size`,
4444 `random_seed`.
45-
4645 """ ,
4746 parameters = dict (
4847 imputation_method = """
@@ -69,6 +68,10 @@ def pca(
6968 sample_query : Optional [base_params .sample_query ] = None ,
7069 sample_query_options : Optional [base_params .sample_query_options ] = None ,
7170 sample_indices : Optional [base_params .sample_indices ] = None ,
71+ cohorts : Optional [base_params .cohorts ] = None ,
72+ cohort_size : Optional [base_params .cohort_size ] = None ,
73+ min_cohort_size : Optional [base_params .min_cohort_size ] = None ,
74+ max_cohort_size : Optional [base_params .max_cohort_size ] = None ,
7275 site_mask : Optional [base_params .site_mask ] = base_params .DEFAULT ,
7376 site_class : Optional [base_params .site_class ] = None ,
7477 min_minor_ac : Optional [
@@ -78,9 +81,6 @@ def pca(
7881 base_params .max_missing_an
7982 ] = pca_params .max_missing_an_default ,
8083 imputation_method : pca_params .imputation_method = pca_params .imputation_method_default ,
81- cohort_size : Optional [base_params .cohort_size ] = None ,
82- min_cohort_size : Optional [base_params .min_cohort_size ] = None ,
83- max_cohort_size : Optional [base_params .max_cohort_size ] = None ,
8484 exclude_samples : Optional [base_params .samples ] = None ,
8585 fit_exclude_samples : Optional [base_params .samples ] = None ,
8686 random_seed : base_params .random_seed = 42 ,
@@ -98,8 +98,44 @@ def pca(
9898
9999 ## Normalize params for consistent hash value.
100100
101- # Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
102- # So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
101+ # Handle cohort downsampling.
102+ if cohorts is not None :
103+ if max_cohort_size is None :
104+ raise ValueError (
105+ "`max_cohort_size` is required when `cohorts` is provided."
106+ )
107+ if sample_indices is not None :
108+ raise ValueError (
109+ "Cannot use `sample_indices` with `cohorts` and `max_cohort_size`."
110+ )
111+ if cohort_size is not None or min_cohort_size is not None :
112+ raise ValueError (
113+ "Cannot use `cohort_size` or `min_cohort_size` with `cohorts`."
114+ )
115+ df_samples = self .sample_metadata (
116+ sample_sets = sample_sets ,
117+ sample_query = sample_query ,
118+ sample_query_options = sample_query_options ,
119+ )
120+ # N.B., the sample_indices parameter is overwritten below with the index labels of the rows kept from each cohort group.
121+ groups = df_samples .groupby (cohorts , sort = False )
122+ ix = []
123+ for _ , group in groups :
124+ if len (group ) > max_cohort_size :
125+ ix .extend (
126+ group .sample (
127+ n = max_cohort_size , random_state = random_seed , replace = False
128+ ).index
129+ )
130+ else :
131+ ix .extend (group .index )
132+ sample_indices = ix
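133+ # From this point onwards the sample selection is defined solely by sample_indices, so the sample_query is cleared.
134+ # NOTE(review): assumes df_samples index labels are positions in the unqueried metadata for sample_sets - confirm.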
135+ sample_query = None
136+ sample_query_options = None
137+
138+ # Normalize params for consistent hash value.
103139 (
104140 prepared_sample_sets ,
105141 prepared_sample_indices ,
@@ -132,6 +168,7 @@ def pca(
132168 max_missing_an = max_missing_an ,
133169 imputation_method = imputation_method ,
134170 n_components = n_components ,
171+ cohorts = cohorts ,
135172 cohort_size = cohort_size ,
136173 min_cohort_size = min_cohort_size ,
137174 max_cohort_size = max_cohort_size ,
@@ -149,10 +186,10 @@ def pca(
149186 self .results_cache_set (name = name , params = params , results = results )
150187
151188 # Unpack results.
152- coords = results ["coords" ]
153- evr = results ["evr" ]
154- samples = results ["samples" ]
155- loc_keep_fit = results ["loc_keep_fit" ]
189+ coords = np . array ( results ["coords" ])
190+ evr = np . array ( results ["evr" ])
191+ samples = np . array ( results ["samples" ])
192+ loc_keep_fit = np . array ( results ["loc_keep_fit" ])
156193
157194 # Create a new DataFrame containing the PCA coords data.
158195 df_pca = pd .DataFrame (coords , index = samples )
@@ -205,6 +242,7 @@ def _pca(
205242 random_seed ,
206243 chunks ,
207244 inline_array ,
245+ ** kwargs ,
208246 ):
209247 # Load diplotypes.
210248 ds_diplotypes = self .biallelic_diplotypes (
0 commit comments