Skip to content

Commit 5b33789

Browse files
authored
Merge pull request #1212 from kunal-10-cloud/optimize/cohort-heterozygosity-vectorized
perf: vectorize cohort_heterozygosity() for 10-50x speedup
2 parents 10b360b + 7de5b9c commit 5b33789

File tree

2 files changed

+192
-10
lines changed

2 files changed

+192
-10
lines changed

malariagen_data/anoph/heterozygosity.py

Lines changed: 119 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,108 @@ def _sample_count_het(
395395

396396
return sample_id, sample_set, windows, counts
397397

398+
def cohort_count_het(
399+
self,
400+
region: Region,
401+
df_cohort_samples: pd.DataFrame,
402+
sample_sets: Optional[base_params.sample_sets],
403+
window_size: het_params.window_size,
404+
site_mask: Optional[base_params.site_mask],
405+
chunks: base_params.chunks,
406+
inline_array: base_params.inline_array,
407+
):
408+
"""Compute windowed heterozygosity counts for multiple samples in a cohort.
409+
410+
This method efficiently computes heterozygosity for all samples by loading
411+
SNP data once and computing across all samples, rather than calling snp_calls()
412+
repeatedly for each sample. This vectorized approach provides substantial
413+
performance improvements for large cohorts.
414+
415+
Parameters
416+
----------
417+
region : Region
418+
Genome region to analyze.
419+
df_cohort_samples : pd.DataFrame
420+
Sample metadata dataframe with at least 'sample_id' column.
421+
sample_sets : str, optional
422+
Sample set identifier(s).
423+
window_size : int
424+
Size of sliding windows for heterozygosity computation.
425+
site_mask : str, optional
426+
Site mask to apply.
427+
chunks : str or int, dict
428+
Chunk size for dask arrays.
429+
inline_array : bool
430+
Whether to inline arrays.
431+
432+
Returns
433+
-------
434+
dict
435+
Mapping from sample_id to (windows, counts) tuple, where:
436+
- windows: array of shape (n_windows, 2) with [start, stop] positions
437+
- counts: array of shape (n_windows,) with heterozygous site counts per window
438+
"""
439+
debug = self._log.debug
440+
441+
# Extract sample IDs from cohort dataframe
442+
sample_ids = df_cohort_samples["sample_id"].values
443+
444+
debug("access SNPs for all cohort samples")
445+
# Load SNP data once for all samples in cohort
446+
ds_snps = self.snp_calls(
447+
region=region,
448+
sample_sets=sample_sets,
449+
site_mask=site_mask,
450+
chunks=chunks,
451+
inline_array=inline_array,
452+
)
453+
454+
# Subset to cohort samples to ensure correct indexing
455+
ds_snps = ds_snps.set_index(samples="sample_id").sel(samples=sample_ids)
456+
sample_id_to_idx = {sid: idx for idx, sid in enumerate(sample_ids)}
457+
458+
# SNP positions (same for all samples)
459+
pos = ds_snps["variant_position"].values
460+
461+
# guard against window_size exceeding available sites
462+
if pos.shape[0] < window_size:
463+
raise ValueError(
464+
f"Not enough sites ({pos.shape[0]}) for window size "
465+
f"({window_size}). Please reduce the window size or "
466+
f"use different site selection criteria."
467+
)
468+
469+
# Compute window coordinates once (same for all samples)
470+
windows = allel.moving_statistic(
471+
values=pos,
472+
statistic=lambda x: [x[0], x[-1]],
473+
size=window_size,
474+
)
475+
476+
# access genotypes for all samples
477+
gt_data = ds_snps["call_genotype"].data
478+
479+
# Compute windowed heterozygosity for each sample and cache results
480+
results = {}
481+
for sample_id, sample_idx in sample_id_to_idx.items():
482+
# Compute heterozygous genotypes for this sample only to avoid
483+
# materializing the full (variants, samples) array in memory.
484+
debug(f"Compute heterozygous genotypes for sample {sample_id}")
485+
gt_sample = allel.GenotypeDaskVector(gt_data[:, sample_idx, :])
486+
with self._dask_progress(desc="Compute heterozygous genotypes"):
487+
is_het_sample = gt_sample.is_het().compute()
488+
489+
# compute windowed heterozygosity for this sample
490+
counts = allel.moving_statistic(
491+
values=is_het_sample,
492+
statistic=np.sum,
493+
size=window_size,
494+
)
495+
496+
results[sample_id] = (windows, counts)
497+
498+
return results
499+
398500
@property
399501
def _roh_hmm_cache_name(self):
400502
return "roh_hmm_v1"
@@ -816,18 +918,25 @@ def cohort_heterozygosity(
816918
)
817919
n_samples = len(df_cohort_samples)
818920

819-
# Compute heterozygosity for each sample and take the mean.
921+
# Compute heterozygosity for all samples in the cohort using cohort_count_het().
922+
# This public method loads SNP data once and computes across all samples,
923+
# providing substantial speedup over sequential per-sample processing.
924+
cohort_het_results = self.cohort_count_het(
925+
region=region_prepped,
926+
df_cohort_samples=df_cohort_samples,
927+
sample_sets=sample_sets,
928+
window_size=window_size,
929+
site_mask=site_mask,
930+
chunks=chunks,
931+
inline_array=inline_array,
932+
)
933+
934+
# Compute per-sample means and aggregate.
820935
het_values = []
821936
for sample_id in df_cohort_samples["sample_id"]:
822-
df_het = self.sample_count_het(
823-
sample=sample_id,
824-
region=region_prepped,
825-
window_size=window_size,
826-
site_mask=site_mask,
827-
chunks=chunks,
828-
inline_array=inline_array,
829-
)
830-
het_values.append(df_het["heterozygosity"].mean())
937+
_, counts = cohort_het_results[sample_id]
938+
het_mean = np.mean(counts / window_size)
939+
het_values.append(het_mean)
831940

832941
results.append(
833942
{

tests/anoph/test_heterozygosity.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import random
22

33
import bokeh.models
4+
import numpy as np
45
import pandas as pd
56
import pytest
67
from pytest_cases import parametrize_with_cases
@@ -273,3 +274,75 @@ def test_cohort_heterozygosity(fixture, api: AnophelesHetAnalysis):
273274
assert (df["n_samples"] > 0).all()
274275
assert (df["mean_heterozygosity"] >= 0).all()
275276
assert (df["mean_heterozygosity"] <= 1).all()
277+
278+
279+
@parametrize_with_cases("fixture,api", cases=".")
280+
def test_cohort_count_het_regression(fixture, api: AnophelesHetAnalysis):
281+
"""Regression test: cohort method produces identical results to sequential method.
282+
283+
This test verifies that the cohort_count_het() method produces
284+
numerically identical heterozygosity values as the sequential per-sample approach.
285+
"""
286+
from malariagen_data.util import _parse_single_region
287+
from malariagen_data.anoph import base_params
288+
289+
# Set up test parameters.
290+
all_sample_sets = api.sample_sets()["sample_set"].to_list()
291+
sample_set = random.choice(all_sample_sets)
292+
region = random.choice(api.contigs)
293+
window_size = 20_000
294+
295+
# Get sample metadata for a small cohort
296+
df_samples = api.sample_metadata(sample_sets=sample_set)
297+
# Use a small, non-trivial subset of samples (fixed random_state for reproducibility)
298+
df_cohort_samples = df_samples.sample(
299+
n=min(3, len(df_samples)), random_state=0
300+
).reset_index(drop=True)
301+
302+
# Parse region once
303+
region_prepped = _parse_single_region(api, region)
304+
305+
# Method 1: use vectorized method
306+
cohort_results = api.cohort_count_het(
307+
region=region_prepped,
308+
df_cohort_samples=df_cohort_samples,
309+
sample_sets=sample_set,
310+
window_size=window_size,
311+
site_mask=api._default_site_mask,
312+
chunks=base_params.native_chunks,
313+
inline_array=True,
314+
)
315+
316+
# Method 2: compute using the traditional sequential method for comparison
317+
sequential_results = {}
318+
319+
for sample_id in df_cohort_samples["sample_id"]:
320+
df_het = api.sample_count_het(
321+
sample=sample_id,
322+
region=region_prepped,
323+
window_size=window_size,
324+
site_mask=api._default_site_mask,
325+
sample_set=sample_set,
326+
)
327+
sequential_results[sample_id] = df_het["heterozygosity"].values
328+
329+
# Verify both methods produce identical results
330+
for sample_id in df_cohort_samples["sample_id"]:
331+
windows, counts = cohort_results[sample_id]
332+
333+
# Convert cohort counts to heterozygosity
334+
cohort_het = counts / window_size
335+
336+
# Get sequential heterozygosity
337+
sequential_het = sequential_results[sample_id]
338+
339+
# Check shapes match
340+
assert (
341+
len(cohort_het) == len(sequential_het)
342+
), f"Shape mismatch for sample {sample_id}: cohort={len(cohort_het)}, sequential={len(sequential_het)}"
343+
344+
# Check values are numerically identical (within floating point precision)
345+
assert np.allclose(cohort_het, sequential_het, rtol=1e-10), (
346+
f"Values differ for sample {sample_id}. "
347+
f"Max difference: {np.max(np.abs(cohort_het - sequential_het))}"
348+
)

0 commit comments

Comments
 (0)