Skip to content

Commit c300a3d

Browse files
committed
Merge branch 'master' of https://github.com/Yashsingh045/malariagen-data-python into GH1221-snp-data-types
2 parents 886c620 + 8d607c2 commit c300a3d

8 files changed

Lines changed: 409 additions & 22 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ To get setup for development, see [this video if you prefer VS Code](https://you
4949
For detailed setup instructions, see:
5050
- [Linux setup guide](LINUX_SETUP.md)
5151
- [macOS setup guide](MACOS_SETUP.md)
52+
- [Windows setup guide](WINDOWS_SETUP.md)
5253
- [Google Colab (TPU) setup guide](docs/source/colab_tpu_runtime.rst)
5354
Detailed instructions can be found in the [Contributors guide](https://github.com/malariagen/malariagen-data-python/blob/master/CONTRIBUTING.md).
5455

WINDOWS_SETUP.md

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Windows Setup Guide
2+
3+
To get setup for development on Windows, see
4+
[this video if you prefer VS Code](https://youtu.be/zddl3n1DCFM),
5+
or [this older video if you prefer PyCharm](https://youtu.be/QniQi-Hoo9A),
6+
and the instructions below.
7+
8+
## 1. Fork and clone this repo
9+
```bash
10+
git clone https://github.com/[username]/malariagen-data-python.git
11+
cd malariagen-data-python
12+
```
13+
14+
## 2. Install Python
15+
16+
Download and install Python 3.10 from the official website:
17+
https://www.python.org/downloads/windows/
18+
19+
During installation, check the box that says Add Python to PATH
20+
before clicking Install.
21+
22+
Verify the installation worked:
23+
```bash
24+
python --version
25+
```
26+
27+
## 3. Install pipx and poetry
28+
```bash
29+
python -m pip install --user pipx
30+
python -m pipx ensurepath
31+
pipx install poetry
32+
```
33+
34+
After running ensurepath, close and reopen PowerShell before continuing.
35+
36+
## 4. Create and activate development environment
37+
```bash
38+
poetry install
39+
poetry shell
40+
```
41+
42+
## 5. Install pre-commit hooks
43+
```bash
44+
pipx install pre-commit
45+
pre-commit install
46+
```
47+
48+
## 6. Add upstream remote and get latest code
49+
```bash
50+
git remote add upstream https://github.com/malariagen/malariagen-data-python
51+
git pull upstream master
52+
```
53+
54+
Note: On Windows the default branch is called master, not main.
55+
56+
## 7. Verify everything works
57+
```bash
58+
python -c "import malariagen_data; print('Setup successful!')"
59+
```
60+
61+
## Common Issues on Windows
62+
63+
**poetry not found after install**
64+
65+
Close and reopen PowerShell, then try again.
66+
67+
**git not recognized**
68+
69+
Install Git from https://git-scm.com/download/win
70+
and restart PowerShell.
71+
72+
**python not recognized**
73+
74+
Reinstall Python and make sure to check
75+
Add Python to PATH during installation.
76+
77+
**fatal: not a git repository**
78+
79+
Make sure you are inside the malariagen-data-python
80+
folder before running any git commands.
81+
```bash
82+
cd malariagen-data-python
83+
```
84+
85+
**error: pathspec main did not match**
86+
87+
On Windows use master instead of main.
88+
```bash
89+
git checkout master
90+
```

malariagen_data/anoph/base.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23

34
import json
45
from contextlib import nullcontext
@@ -496,7 +497,20 @@ def client_location(self) -> str:
496497
return location
497498

498499
def _surveillance_flags(self, sample_sets: List[str]):
499-
raise NotImplementedError("Subclasses must implement `_surveillance_flags`.")
500+
"""Return surveillance flags for sample sets. Subclasses should override to
501+
load real data; this base implementation returns empty data and warns.
502+
"""
503+
warnings.warn(
504+
"Surveillance flags not implemented for this resource; returning empty data.",
505+
UserWarning,
506+
stacklevel=2,
507+
)
508+
return pd.DataFrame(
509+
{
510+
"sample_id": pd.Series(dtype="object"),
511+
"is_surveillance": pd.Series(dtype="boolean"),
512+
}
513+
)
500514

501515
def _release_has_unrestricted_data(self, *, release: str):
502516
"""Return `True` if the specified release has any unrestricted data. Otherwise return `False`."""

malariagen_data/anoph/heterozygosity.py

Lines changed: 119 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,108 @@ def _sample_count_het(
395395

396396
return sample_id, sample_set, windows, counts
397397

398+
def cohort_count_het(
399+
self,
400+
region: Region,
401+
df_cohort_samples: pd.DataFrame,
402+
sample_sets: Optional[base_params.sample_sets],
403+
window_size: het_params.window_size,
404+
site_mask: Optional[base_params.site_mask],
405+
chunks: base_params.chunks,
406+
inline_array: base_params.inline_array,
407+
):
408+
"""Compute windowed heterozygosity counts for multiple samples in a cohort.
409+
410+
This method efficiently computes heterozygosity for all samples by loading
411+
SNP data once and computing across all samples, rather than calling snp_calls()
412+
repeatedly for each sample. This vectorized approach provides substantial
413+
performance improvements for large cohorts.
414+
415+
Parameters
416+
----------
417+
region : Region
418+
Genome region to analyze.
419+
df_cohort_samples : pd.DataFrame
420+
Sample metadata dataframe with at least 'sample_id' column.
421+
sample_sets : str, optional
422+
Sample set identifier(s).
423+
window_size : int
424+
Size of sliding windows for heterozygosity computation.
425+
site_mask : str, optional
426+
Site mask to apply.
427+
chunks : str or int, dict
428+
Chunk size for dask arrays.
429+
inline_array : bool
430+
Whether to inline arrays.
431+
432+
Returns
433+
-------
434+
dict
435+
Mapping from sample_id to (windows, counts) tuple, where:
436+
- windows: array of shape (n_windows, 2) with [start, stop] positions
437+
- counts: array of shape (n_windows,) with heterozygous site counts per window
438+
"""
439+
debug = self._log.debug
440+
441+
# Extract sample IDs from cohort dataframe
442+
sample_ids = df_cohort_samples["sample_id"].values
443+
444+
debug("access SNPs for all cohort samples")
445+
# Load SNP data once for all samples in cohort
446+
ds_snps = self.snp_calls(
447+
region=region,
448+
sample_sets=sample_sets,
449+
site_mask=site_mask,
450+
chunks=chunks,
451+
inline_array=inline_array,
452+
)
453+
454+
# Subset to cohort samples to ensure correct indexing
455+
ds_snps = ds_snps.set_index(samples="sample_id").sel(samples=sample_ids)
456+
sample_id_to_idx = {sid: idx for idx, sid in enumerate(sample_ids)}
457+
458+
# SNP positions (same for all samples)
459+
pos = ds_snps["variant_position"].values
460+
461+
# guard against window_size exceeding available sites
462+
if pos.shape[0] < window_size:
463+
raise ValueError(
464+
f"Not enough sites ({pos.shape[0]}) for window size "
465+
f"({window_size}). Please reduce the window size or "
466+
f"use different site selection criteria."
467+
)
468+
469+
# Compute window coordinates once (same for all samples)
470+
windows = allel.moving_statistic(
471+
values=pos,
472+
statistic=lambda x: [x[0], x[-1]],
473+
size=window_size,
474+
)
475+
476+
# access genotypes for all samples
477+
gt_data = ds_snps["call_genotype"].data
478+
479+
# Compute windowed heterozygosity for each sample and cache results
480+
results = {}
481+
for sample_id, sample_idx in sample_id_to_idx.items():
482+
# Compute heterozygous genotypes for this sample only to avoid
483+
# materializing the full (variants, samples) array in memory.
484+
debug(f"Compute heterozygous genotypes for sample {sample_id}")
485+
gt_sample = allel.GenotypeDaskVector(gt_data[:, sample_idx, :])
486+
with self._dask_progress(desc="Compute heterozygous genotypes"):
487+
is_het_sample = gt_sample.is_het().compute()
488+
489+
# compute windowed heterozygosity for this sample
490+
counts = allel.moving_statistic(
491+
values=is_het_sample,
492+
statistic=np.sum,
493+
size=window_size,
494+
)
495+
496+
results[sample_id] = (windows, counts)
497+
498+
return results
499+
398500
@property
399501
def _roh_hmm_cache_name(self):
400502
return "roh_hmm_v1"
@@ -816,18 +918,25 @@ def cohort_heterozygosity(
816918
)
817919
n_samples = len(df_cohort_samples)
818920

819-
# Compute heterozygosity for each sample and take the mean.
921+
# Compute heterozygosity for all samples in the cohort using cohort_count_het().
922+
# This public method loads SNP data once and computes across all samples,
923+
# providing substantial speedup over sequential per-sample processing.
924+
cohort_het_results = self.cohort_count_het(
925+
region=region_prepped,
926+
df_cohort_samples=df_cohort_samples,
927+
sample_sets=sample_sets,
928+
window_size=window_size,
929+
site_mask=site_mask,
930+
chunks=chunks,
931+
inline_array=inline_array,
932+
)
933+
934+
# Compute per-sample means and aggregate.
820935
het_values = []
821936
for sample_id in df_cohort_samples["sample_id"]:
822-
df_het = self.sample_count_het(
823-
sample=sample_id,
824-
region=region_prepped,
825-
window_size=window_size,
826-
site_mask=site_mask,
827-
chunks=chunks,
828-
inline_array=inline_array,
829-
)
830-
het_values.append(df_het["heterozygosity"].mean())
937+
_, counts = cohort_het_results[sample_id]
938+
het_mean = np.mean(counts / window_size)
939+
het_values.append(het_mean)
831940

832941
results.append(
833942
{

malariagen_data/util.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,9 @@ def __eq__(self, other):
570570
and (self.end == other.end)
571571
)
572572

573+
def __repr__(self):
574+
return f"Region({self._contig!r}, {self._start!r}, {self._end!r})"
575+
573576
def __str__(self):
574577
out = self._contig
575578
if self._start is not None or self._end is not None:
@@ -927,7 +930,20 @@ def _jitter(a, fraction, random_state=np.random):
927930

928931

929932
class CacheMiss(Exception):
930-
pass
933+
"""Raised when a requested item is not present in the cache."""
934+
935+
def __init__(self, key=None):
936+
self.key = key
937+
if key is not None:
938+
message = f"Cache miss for key: {key!r}"
939+
else:
940+
message = "Cache miss: requested item not found in cache."
941+
super().__init__(message)
942+
943+
def __repr__(self):
944+
if self.key is not None:
945+
return f"CacheMiss({self.key!r})"
946+
return "CacheMiss()"
931947

932948

933949
class LoggingHelper:
@@ -1531,12 +1547,10 @@ def _apply_allele_mapping(x, mapping, max_allele):
15311547

15321548
def _dask_apply_allele_mapping(v, mapping, max_allele):
15331549
if not isinstance(v, da.Array):
1534-
raise TypeError(
1535-
f"Expected v to be a dask.array.Array, " f"got {type(v).__name__}"
1536-
)
1550+
raise TypeError(f"Expected v to be a dask.array.Array, got {type(v).__name__}")
15371551
if not isinstance(mapping, np.ndarray):
15381552
raise TypeError(
1539-
f"Expected mapping to be a numpy.ndarray, " f"got {type(mapping).__name__}"
1553+
f"Expected mapping to be a numpy.ndarray, got {type(mapping).__name__}"
15401554
)
15411555
assert v.ndim == 2
15421556
assert mapping.ndim == 2
@@ -1558,12 +1572,10 @@ def _genotype_array_map_alleles(gt, mapping):
15581572
# N.B., scikit-allel does not handle empty blocks well, so we
15591573
# include some extra logic to handle that better.
15601574
if not isinstance(gt, np.ndarray):
1561-
raise TypeError(
1562-
f"Expected gt to be a numpy.ndarray, " f"got {type(gt).__name__}"
1563-
)
1575+
raise TypeError(f"Expected gt to be a numpy.ndarray, got {type(gt).__name__}")
15641576
if not isinstance(mapping, np.ndarray):
15651577
raise TypeError(
1566-
f"Expected mapping to be a numpy.ndarray, " f"got {type(mapping).__name__}"
1578+
f"Expected mapping to be a numpy.ndarray, got {type(mapping).__name__}"
15671579
)
15681580
assert gt.ndim == 3
15691581
assert mapping.ndim == 3
@@ -1585,11 +1597,11 @@ def _genotype_array_map_alleles(gt, mapping):
15851597
def _dask_genotype_array_map_alleles(gt, mapping):
15861598
if not isinstance(gt, da.Array):
15871599
raise TypeError(
1588-
f"Expected gt to be a dask.array.Array, " f"got {type(gt).__name__}"
1600+
f"Expected gt to be a dask.array.Array, got {type(gt).__name__}"
15891601
)
15901602
if not isinstance(mapping, np.ndarray):
15911603
raise TypeError(
1592-
f"Expected mapping to be a numpy.ndarray, " f"got {type(mapping).__name__}"
1604+
f"Expected mapping to be a numpy.ndarray, got {type(mapping).__name__}"
15931605
)
15941606
assert gt.ndim == 3
15951607
assert mapping.ndim == 2

tests/anoph/test_base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,28 @@ def test_sample_sets_no_terms_of_use(ag3_sim_fixture):
411411
finally:
412412
for mp, bp in zip(manifest_paths, backups):
413413
shutil.move(bp, mp)
414+
415+
416+
class TestSurveillanceFlagsBaseFallback:
417+
"""Tests for issue #1206: base _surveillance_flags graceful fallback."""
418+
419+
def test_surveillance_flags_base_returns_empty_and_warns(self, ag3_sim_api):
420+
"""Base implementation returns empty DataFrame with correct schema and warns."""
421+
with pytest.warns(UserWarning, match="Surveillance flags not implemented"):
422+
df = ag3_sim_api._surveillance_flags(sample_sets=["AG1000G-AO"])
423+
424+
assert isinstance(df, pd.DataFrame)
425+
assert list(df.columns) == ["sample_id", "is_surveillance"]
426+
assert df["sample_id"].dtype == object
427+
assert pd.api.types.is_bool_dtype(df["is_surveillance"])
428+
assert len(df) == 0
429+
430+
def test_sample_set_has_surveillance_data_returns_false_when_fallback(
431+
self, ag3_sim_api
432+
):
433+
"""_sample_set_has_surveillance_data returns False when base fallback is used."""
434+
with pytest.warns(UserWarning, match="Surveillance flags not implemented"):
435+
result = ag3_sim_api._sample_set_has_surveillance_data(
436+
sample_set="AG1000G-AO"
437+
)
438+
assert not result

0 commit comments

Comments
 (0)