Skip to content

Commit ad8b35d

Browse files
authored
Merge branch 'master' into fix/issue-1280-vcf-performance
2 parents f63ebc1 + 06563cf commit ad8b35d

File tree

3 files changed

+156
-11
lines changed

3 files changed

+156
-11
lines changed

malariagen_data/anoph/pca.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def __init__(
4242
The following additional parameters were also added in version 8.0.0:
4343
`site_class`, `cohort_size`, `min_cohort_size`, `max_cohort_size`,
4444
`random_seed`.
45-
4645
""",
4746
parameters=dict(
4847
imputation_method="""
@@ -69,6 +68,10 @@ def pca(
6968
sample_query: Optional[base_params.sample_query] = None,
7069
sample_query_options: Optional[base_params.sample_query_options] = None,
7170
sample_indices: Optional[base_params.sample_indices] = None,
71+
cohorts: Optional[base_params.cohorts] = None,
72+
cohort_size: Optional[base_params.cohort_size] = None,
73+
min_cohort_size: Optional[base_params.min_cohort_size] = None,
74+
max_cohort_size: Optional[base_params.max_cohort_size] = None,
7275
site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
7376
site_class: Optional[base_params.site_class] = None,
7477
min_minor_ac: Optional[
@@ -78,9 +81,6 @@ def pca(
7881
base_params.max_missing_an
7982
] = pca_params.max_missing_an_default,
8083
imputation_method: pca_params.imputation_method = pca_params.imputation_method_default,
81-
cohort_size: Optional[base_params.cohort_size] = None,
82-
min_cohort_size: Optional[base_params.min_cohort_size] = None,
83-
max_cohort_size: Optional[base_params.max_cohort_size] = None,
8484
exclude_samples: Optional[base_params.samples] = None,
8585
fit_exclude_samples: Optional[base_params.samples] = None,
8686
random_seed: base_params.random_seed = 42,
@@ -98,8 +98,44 @@ def pca(
9898

9999
# Resolve the sample selection, then normalize params for a consistent hash value.
100100

101-
# Note: `_prep_sample_selection_cache_params` converts `sample_query` and `sample_query_options` into `sample_indices`.
102-
# So `sample_query` and `sample_query_options` should not be used beyond this point. (`sample_indices` should be used instead.)
101+
# Handle cohort downsampling.
102+
if cohorts is not None:
103+
if max_cohort_size is None:
104+
raise ValueError(
105+
"`max_cohort_size` is required when `cohorts` is provided."
106+
)
107+
if sample_indices is not None:
108+
raise ValueError(
109+
"Cannot use `sample_indices` with `cohorts` and `max_cohort_size`."
110+
)
111+
if cohort_size is not None or min_cohort_size is not None:
112+
raise ValueError(
113+
"Cannot use `cohort_size` or `min_cohort_size` with `cohorts`."
114+
)
115+
df_samples = self.sample_metadata(
116+
sample_sets=sample_sets,
117+
sample_query=sample_query,
118+
sample_query_options=sample_query_options,
119+
)
120+
# N.B., we are going to overwrite the sample_indices parameter here.
121+
groups = df_samples.groupby(cohorts, sort=False)
122+
ix = []
123+
for _, group in groups:
124+
if len(group) > max_cohort_size:
125+
ix.extend(
126+
group.sample(
127+
n=max_cohort_size, random_state=random_seed, replace=False
128+
).index
129+
)
130+
else:
131+
ix.extend(group.index)
132+
sample_indices = ix
133+
# From this point onwards, the sample_query is no longer needed, because
134+
# the sample selection is defined by the sample_indices.
135+
sample_query = None
136+
sample_query_options = None
137+
138+
# Normalize params for consistent hash value.
103139
(
104140
prepared_sample_sets,
105141
prepared_sample_indices,
@@ -132,6 +168,7 @@ def pca(
132168
max_missing_an=max_missing_an,
133169
imputation_method=imputation_method,
134170
n_components=n_components,
171+
cohorts=cohorts,
135172
cohort_size=cohort_size,
136173
min_cohort_size=min_cohort_size,
137174
max_cohort_size=max_cohort_size,
@@ -149,10 +186,10 @@ def pca(
149186
self.results_cache_set(name=name, params=params, results=results)
150187

151188
# Unpack results.
152-
coords = results["coords"]
153-
evr = results["evr"]
154-
samples = results["samples"]
155-
loc_keep_fit = results["loc_keep_fit"]
189+
coords = np.array(results["coords"])
190+
evr = np.array(results["evr"])
191+
samples = np.array(results["samples"])
192+
loc_keep_fit = np.array(results["loc_keep_fit"])
156193

157194
# Create a new DataFrame containing the PCA coords data.
158195
df_pca = pd.DataFrame(coords, index=samples)
@@ -205,6 +242,7 @@ def _pca(
205242
random_seed,
206243
chunks,
207244
inline_array,
245+
**kwargs,
208246
):
209247
# Load diplotypes.
210248
ds_diplotypes = self.biallelic_diplotypes(

notebooks/plot_pca.ipynb

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,10 +620,38 @@
620620
")"
621621
]
622622
},
623+
{
624+
"cell_type": "markdown",
625+
"id": "f1e8c954",
626+
"metadata": {},
627+
"source": [
628+
"## PCA with cohort downsampling"
629+
]
630+
},
631+
{
632+
"cell_type": "code",
633+
"execution_count": null,
634+
"id": "e4a484f3",
635+
"metadata": {},
636+
"outputs": [],
637+
"source": [
638+
"df_pca_cohorts, evr_cohorts = ag3.pca(\n",
639+
" region=\"3L:15,000,000-16,000,000\",\n",
640+
" sample_sets=\"3.0\",\n",
641+
" n_snps=10_000,\n",
642+
" cohorts=\"country\",\n",
643+
" max_cohort_size=20,\n",
644+
")\n",
645+
"ag3.plot_pca_coords(\n",
646+
" df_pca_cohorts,\n",
647+
" color=\"country\",\n",
648+
")"
649+
]
650+
},
623651
{
624652
"cell_type": "code",
625653
"execution_count": null,
626-
"id": "33d788a2-f256-4930-b1e5-b4f31e681a36",
654+
"id": "abb2ee83",
627655
"metadata": {},
628656
"outputs": [],
629657
"source": []

tests/anoph/test_pca.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,85 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca):
340340
)
341341

342342

343+
@parametrize_with_cases("fixture,api", cases=".")
def test_pca_cohort_downsampling(fixture, api: AnophelesPca):
    """Check that `pca()` caps cohort sizes when `cohorts` is given.

    Runs a PCA with `cohorts` and `max_cohort_size` set, checks the types
    and basic structure of the returned data frame and explained variance
    ratio, checks that no cohort in the output exceeds `max_cohort_size`,
    and finally checks that incompatible parameter combinations are
    rejected with the expected error message.
    """
    # Parameters for selecting input data.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    sample_sets = np.random.choice(all_sample_sets, size=2, replace=False).tolist()
    data_params = dict(
        region=str(np.random.choice(api.contigs)),
        sample_sets=sample_sets,
        site_mask=np.random.choice(list(api.site_mask_ids) + [None]),
    )

    # Parameters for cohort downsampling.
    cohort_col = "country"
    max_cohort_size = 10
    random_seed = 42

    # Parameters shared by every call below.
    pca_params = dict(
        n_snps=100,  # Use a small number to avoid "Not enough SNPs" errors
        n_components=2,
        cohorts=cohort_col,
        **data_params,
    )

    # Try to run the PCA with cohort downsampling.
    try:
        pca_df, pca_evr = api.pca(
            max_cohort_size=max_cohort_size,
            random_seed=random_seed,
            **pca_params,
        )
    except ValueError as e:
        if "Not enough SNPs" in str(e):
            pytest.skip("Not enough SNPs available after downsampling to run test.")
        else:
            raise

    # Check types.
    assert isinstance(pca_df, pd.DataFrame)
    assert isinstance(pca_evr, np.ndarray)

    # Check basic structure.
    assert len(pca_df) > 0
    assert "PC1" in pca_df.columns
    assert "PC2" in pca_df.columns
    assert "pca_fit" in pca_df.columns
    assert pca_df["pca_fit"].all()
    assert pca_evr.ndim == 1
    assert pca_evr.shape[0] == 2

    # Check cohort counts: no cohort may exceed the requested maximum.
    final_cohort_counts = pca_df[cohort_col].value_counts()
    oversized = final_cohort_counts[final_cohort_counts > max_cohort_size]
    assert (
        oversized.empty
    ), f"cohorts exceed max_cohort_size={max_cohort_size}: {oversized.to_dict()}"

    # Test bad parameter combinations. Each case is pinned to the message
    # raised by the parameter validation, because a bare ValueError check
    # would also be satisfied by an unrelated error such as
    # "Not enough SNPs".
    bad_param_cases = [
        # max_cohort_size is missing
        (
            dict(),
            "`max_cohort_size` is required when `cohorts` is provided",
        ),
        # sample_indices cannot be combined with cohorts
        (
            dict(max_cohort_size=max_cohort_size, sample_indices=[0, 1, 2]),
            "Cannot use `sample_indices` with `cohorts`",
        ),
        # cohort_size cannot be combined with cohorts
        (
            dict(max_cohort_size=max_cohort_size, cohort_size=10),
            "Cannot use `cohort_size` or `min_cohort_size` with `cohorts`",
        ),
    ]
    for extra_params, expected_message in bad_param_cases:
        with pytest.raises(ValueError, match=expected_message):
            api.pca(**pca_params, **extra_params)
420+
421+
343422
# --- _jitter() determinism unit tests ---
344423

345424

0 commit comments

Comments
 (0)