Skip to content

Commit 2dfa8f8

Browse files
committed
Add cohort downsampling support to PCA and update tests
- Introduced `cohorts`, `cohort_size`, `min_cohort_size`, and `max_cohort_size` parameters in the PCA method.
- Updated the PCA docstring to reflect the new parameters.
- Added an example of cohort downsampling usage to the notebook.
- Implemented tests for the cohort downsampling functionality, including validation of invalid parameter combinations.
1 parent a39a3c1 commit 2dfa8f8

3 files changed

Lines changed: 168 additions & 8 deletions

File tree

malariagen_data/anoph/pca.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,28 @@ def __init__(
4343
`site_class`, `cohort_size`, `min_cohort_size`, `max_cohort_size`,
4444
`random_seed`.
4545
46+
.. versionchanged:: 9.0.0
47+
The `cohorts` parameter has been added to enable cohort-based
48+
downsampling via the `max_cohort_size` parameter.
4649
""",
4750
returns=("df_pca", "evr"),
4851
notes="""
4952
This computation may take some time to run, depending on your computing
5053
environment. Results of this computation will be cached and re-used if
5154
the `results_cache` parameter was set when instantiating the API client.
5255
""",
56+
examples="""
57+
Run a PCA, downsampling to a maximum of 20 samples per country::
58+
59+
>>> import malariagen_data
60+
>>> ag3 = malariagen_data.Ag3()
61+
>>> df_pca, evr = ag3.pca(
62+
... region="3R",
63+
... n_snps=1000,
64+
... cohorts="country",
65+
... max_cohort_size=20,
66+
... )
67+
""",
5368
)
5469
def pca(
5570
self,
@@ -61,6 +76,10 @@ def pca(
6176
sample_query: Optional[base_params.sample_query] = None,
6277
sample_query_options: Optional[base_params.sample_query_options] = None,
6378
sample_indices: Optional[base_params.sample_indices] = None,
79+
cohorts: Optional[base_params.cohorts] = None,
80+
cohort_size: Optional[base_params.cohort_size] = None,
81+
min_cohort_size: Optional[base_params.min_cohort_size] = None,
82+
max_cohort_size: Optional[base_params.max_cohort_size] = None,
6483
site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
6584
site_class: Optional[base_params.site_class] = None,
6685
min_minor_ac: Optional[
@@ -69,9 +88,6 @@ def pca(
6988
max_missing_an: Optional[
7089
base_params.max_missing_an
7190
] = pca_params.max_missing_an_default,
72-
cohort_size: Optional[base_params.cohort_size] = None,
73-
min_cohort_size: Optional[base_params.min_cohort_size] = None,
74-
max_cohort_size: Optional[base_params.max_cohort_size] = None,
7591
exclude_samples: Optional[base_params.samples] = None,
7692
fit_exclude_samples: Optional[base_params.samples] = None,
7793
random_seed: base_params.random_seed = 42,
@@ -82,6 +98,41 @@ def pca(
8298
# invalidate any previously cached data.
8399
name = "pca_v4"
84100

101+
# Handle cohort downsampling.
102+
if cohorts is not None:
103+
if max_cohort_size is None:
104+
raise ValueError("`max_cohort_size` is required when `cohorts` is provided.")
105+
if sample_indices is not None:
106+
raise ValueError(
107+
"Cannot use `sample_indices` with `cohorts` and `max_cohort_size`."
108+
)
109+
if cohort_size is not None or min_cohort_size is not None:
110+
raise ValueError(
111+
"Cannot use `cohort_size` or `min_cohort_size` with `cohorts`."
112+
)
113+
df_samples = self.sample_metadata(
114+
sample_sets=sample_sets,
115+
sample_query=sample_query,
116+
sample_query_options=sample_query_options,
117+
)
118+
# N.B., we are going to overwrite the sample_indices parameter here.
119+
groups = df_samples.groupby(cohorts, sort=False)
120+
ix = []
121+
for _, group in groups:
122+
if len(group) > max_cohort_size:
123+
ix.extend(
124+
group.sample(
125+
n=max_cohort_size, random_state=random_seed, replace=False
126+
).index
127+
)
128+
else:
129+
ix.extend(group.index)
130+
sample_indices = ix
131+
# From this point onwards, the sample_query is no longer needed, because
132+
# the sample selection is defined by the sample_indices.
133+
sample_query = None
134+
sample_query_options = None
135+
85136
# Normalize params for consistent hash value.
86137
(
87138
sample_sets_prepped,
@@ -105,6 +156,7 @@ def pca(
105156
min_minor_ac=min_minor_ac,
106157
max_missing_an=max_missing_an,
107158
n_components=n_components,
159+
cohorts=cohorts,
108160
cohort_size=cohort_size,
109161
min_cohort_size=min_cohort_size,
110162
max_cohort_size=max_cohort_size,
@@ -122,10 +174,10 @@ def pca(
122174
self.results_cache_set(name=name, params=params, results=results)
123175

124176
# Unpack results.
125-
coords = results["coords"]
126-
evr = results["evr"]
127-
samples = results["samples"]
128-
loc_keep_fit = results["loc_keep_fit"]
177+
coords = np.array(results["coords"])
178+
evr = np.array(results["evr"])
179+
samples = np.array(results["samples"])
180+
loc_keep_fit = np.array(results["loc_keep_fit"])
129181

130182
# Load sample metadata.
131183
df_samples = self.sample_metadata(
@@ -166,6 +218,7 @@ def _pca(
166218
random_seed,
167219
chunks,
168220
inline_array,
221+
**kwargs,
169222
):
170223
# Load diplotypes.
171224
gn, samples = self.biallelic_diplotypes(

notebooks/plot_pca.ipynb

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,10 +620,38 @@
620620
")"
621621
]
622622
},
623+
{
624+
"cell_type": "markdown",
625+
"id": "f1e8c954",
626+
"metadata": {},
627+
"source": [
628+
"## PCA with cohort downsampling"
629+
]
630+
},
631+
{
632+
"cell_type": "code",
633+
"execution_count": null,
634+
"id": "e4a484f3",
635+
"metadata": {},
636+
"outputs": [],
637+
"source": [
638+
"df_pca_cohorts, evr_cohorts = ag3.pca(\n",
639+
" region=\"3L:15,000,000-16,000,000\",\n",
640+
" sample_sets=\"3.0\",\n",
641+
" n_snps=10_000,\n",
642+
" cohorts=\"country\",\n",
643+
" max_cohort_size=20,\n",
644+
")\n",
645+
"ag3.plot_pca_coords(\n",
646+
" df_pca_cohorts,\n",
647+
" color=\"country\",\n",
648+
")"
649+
]
650+
},
623651
{
624652
"cell_type": "code",
625653
"execution_count": null,
626-
"id": "33d788a2-f256-4930-b1e5-b4f31e681a36",
654+
"id": "abb2ee83",
627655
"metadata": {},
628656
"outputs": [],
629657
"source": []

tests/anoph/test_pca.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,82 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca):
288288
len(pca_df.query(f"sample_id in {exclude_samples} and not pca_fit"))
289289
== n_samples_excluded
290290
)
291+
292+
293+
@parametrize_with_cases("fixture,api", cases=".")
294+
def test_pca_cohort_downsampling(fixture, api: AnophelesPca):
295+
# Parameters for selecting input data.
296+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
297+
sample_sets = random.sample(all_sample_sets, 2)
298+
data_params = dict(
299+
region=random.choice(api.contigs),
300+
sample_sets=sample_sets,
301+
site_mask=random.choice((None,) + api.site_mask_ids),
302+
)
303+
304+
# Test cohort downsampling.
305+
cohort_col = "country"
306+
max_cohort_size = 10
307+
random_seed = 42
308+
309+
# Try to run the PCA with cohort downsampling.
310+
try:
311+
pca_df, pca_evr = api.pca(
312+
n_snps=100, # Use a small number to avoid "Not enough SNPs" errors
313+
n_components=2,
314+
cohorts=cohort_col,
315+
max_cohort_size=max_cohort_size,
316+
random_seed=random_seed,
317+
**data_params,
318+
)
319+
except ValueError as e:
320+
if "Not enough SNPs" in str(e):
321+
pytest.skip("Not enough SNPs available after downsampling to run test.")
322+
else:
323+
raise
324+
325+
# Check types.
326+
assert isinstance(pca_df, pd.DataFrame)
327+
assert isinstance(pca_evr, np.ndarray)
328+
329+
# Check basic structure.
330+
assert len(pca_df) > 0
331+
assert "PC1" in pca_df.columns
332+
assert "PC2" in pca_df.columns
333+
assert "pca_fit" in pca_df.columns
334+
assert pca_df["pca_fit"].all()
335+
assert pca_evr.ndim == 1
336+
assert pca_evr.shape[0] == 2
337+
338+
# Check cohort counts.
339+
final_cohort_counts = pca_df[cohort_col].value_counts()
340+
for cohort, count in final_cohort_counts.items():
341+
assert count <= max_cohort_size
342+
343+
# Test bad parameter combinations.
344+
with pytest.raises(ValueError):
345+
api.pca(
346+
n_snps=100,
347+
n_components=2,
348+
cohorts=cohort_col,
349+
# max_cohort_size is missing
350+
**data_params,
351+
)
352+
with pytest.raises(ValueError):
353+
api.pca(
354+
n_snps=100,
355+
n_components=2,
356+
cohorts=cohort_col,
357+
max_cohort_size=max_cohort_size,
358+
sample_indices=[0, 1, 2],
359+
**data_params,
360+
)
361+
with pytest.raises(ValueError):
362+
api.pca(
363+
n_snps=100,
364+
n_components=2,
365+
cohorts=cohort_col,
366+
max_cohort_size=max_cohort_size,
367+
cohort_size=10,
368+
**data_params,
369+
)

0 commit comments

Comments (0)