Skip to content

Commit bee0271

Browse files
fix: correct dask array extraction in _cohort_count_het_vectorized
- Store raw dask array before subsetting to avoid AttributeError on .data - Access gt_data directly instead of wrapping then slicing GenotypeDaskArray - All 28 tests pass (4 cohort_heterozygosity + 4 regression + 20 others) - Maintains memory optimization: per-sample computation avoids materializing full array - Addresses final Copilot code review suggestion
1 parent e50eca7 commit bee0271

2 files changed

Lines changed: 16 additions & 12 deletions

File tree

malariagen_data/anoph/heterozygosity.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,6 @@ def _cohort_count_het_vectorized(
439439

440440
# Extract sample IDs from cohort dataframe
441441
sample_ids = df_cohort_samples["sample_id"].values
442-
sample_id_to_idx = {sid: idx for idx, sid in enumerate(sample_ids)}
443442

444443
debug("access SNPs for all cohort samples")
445444
# Load SNP data once for all samples in cohort
@@ -451,6 +450,10 @@ def _cohort_count_het_vectorized(
451450
inline_array=inline_array,
452451
)
453452

453+
# Subset to cohort samples to ensure correct indexing
454+
ds_snps = ds_snps.set_index(samples="sample_id").sel(samples=sample_ids)
455+
sample_id_to_idx = {sid: idx for idx, sid in enumerate(sample_ids)}
456+
454457
# SNP positions (same for all samples)
455458
pos = ds_snps["variant_position"].values
456459

@@ -470,18 +473,17 @@ def _cohort_count_het_vectorized(
470473
)
471474

472475
# access genotypes for all samples
473-
gt = allel.GenotypeDaskArray(ds_snps["call_genotype"].data)
474-
475-
# compute het across all samples: shape (variants, samples)
476-
debug("Compute heterozygous genotypes for all samples")
477-
with self._dask_progress(desc="Compute heterozygous genotypes"):
478-
is_het_all = gt.is_het().compute()
476+
gt_data = ds_snps["call_genotype"].data
479477

480478
# Compute windowed heterozygosity for each sample and cache results
481479
results = {}
482480
for sample_id, sample_idx in sample_id_to_idx.items():
483-
# Extract heterozygosity column for this sample
484-
is_het_sample = is_het_all[:, sample_idx]
481+
# Compute heterozygous genotypes for this sample only to avoid
482+
# materializing the full (variants, samples) array in memory.
483+
debug(f"Compute heterozygous genotypes for sample {sample_id}")
484+
gt_sample = allel.GenotypeDaskVector(gt_data[:, sample_idx, :])
485+
with self._dask_progress(desc="Compute heterozygous genotypes"):
486+
is_het_sample = gt_sample.is_het().compute()
485487

486488
# compute windowed heterozygosity for this sample
487489
counts = allel.moving_statistic(
@@ -910,7 +912,7 @@ def cohort_heterozygosity(
910912
# Compute per-sample means and aggregate.
911913
het_values = []
912914
for sample_id in df_cohort_samples["sample_id"]:
913-
windows, counts = cohort_het_results[sample_id]
915+
_, counts = cohort_het_results[sample_id]
914916
het_mean = np.mean(counts / window_size)
915917
het_values.append(het_mean)
916918

tests/anoph/test_heterozygosity.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,10 @@ def test_cohort_count_het_vectorized_regression(fixture, api: AnophelesHetAnalys
281281

282282
# Get sample metadata for a small cohort
283283
df_samples = api.sample_metadata(sample_sets=sample_set)
284-
# Use first few samples to keep test fast
285-
df_cohort_samples = df_samples.head(min(3, len(df_samples))).reset_index(drop=True)
284+
# Use a small, non-trivial subset of samples (fixed random_state for reproducibility)
285+
df_cohort_samples = df_samples.sample(
286+
n=min(3, len(df_samples)), random_state=0
287+
).reset_index(drop=True)
286288

287289
# Parse region once
288290
region_prepped = _parse_single_region(api, region)

0 commit comments

Comments
 (0)