Skip to content

Commit c3212b3

Browse files
Authored merge commit: Merge branch 'master' into GH-1054-add-vcf-export
2 parents a33425b + fed6010, commit c3212b3

12 files changed

Lines changed: 288 additions & 36 deletions

File tree

malariagen_data/adar1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _repr_html_(self):
181181
<th style="text-align: left">
182182
Data releases available
183183
</th>
184-
<td>{', '.join(self._available_releases)}</td>
184+
<td>{", ".join(self._available_releases)}</td>
185185
</tr>
186186
<tr>
187187
<th style="text-align: left">
@@ -229,7 +229,7 @@ def _repr_html_(self):
229229
<th style="text-align: left">
230230
Relevant data releases
231231
</th>
232-
<td>{', '.join(self.releases)}</td>
232+
<td>{", ".join(self.releases)}</td>
233233
</tr>
234234
</tbody>
235235
</table>

malariagen_data/adir1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _repr_html_(self):
181181
<th style="text-align: left">
182182
Data releases available
183183
</th>
184-
<td>{', '.join(self._available_releases)}</td>
184+
<td>{", ".join(self._available_releases)}</td>
185185
</tr>
186186
<tr>
187187
<th style="text-align: left">
@@ -229,7 +229,7 @@ def _repr_html_(self):
229229
<th style="text-align: left">
230230
Relevant data releases
231231
</th>
232-
<td>{', '.join(self.releases)}</td>
232+
<td>{", ".join(self.releases)}</td>
233233
</tr>
234234
</tbody>
235235
</table>

malariagen_data/af1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _repr_html_(self):
183183
<th style="text-align: left">
184184
Data releases available
185185
</th>
186-
<td>{', '.join(self._available_releases)}</td>
186+
<td>{", ".join(self._available_releases)}</td>
187187
</tr>
188188
<tr>
189189
<th style="text-align: left">
@@ -231,7 +231,7 @@ def _repr_html_(self):
231231
<th style="text-align: left">
232232
Relevant data releases
233233
</th>
234-
<td>{', '.join(self.releases)}</td>
234+
<td>{", ".join(self.releases)}</td>
235235
</tr>
236236
</tbody>
237237
</table>

malariagen_data/ag3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def _repr_html_(self):
278278
<th style="text-align: left">
279279
Data releases available
280280
</th>
281-
<td>{', '.join(self._available_releases)}</td>
281+
<td>{", ".join(self._available_releases)}</td>
282282
</tr>
283283
<tr>
284284
<th style="text-align: left">
@@ -332,7 +332,7 @@ def _repr_html_(self):
332332
<th style="text-align: left">
333333
Relevant data releases
334334
</th>
335-
<td>{', '.join(self.releases)}</td>
335+
<td>{", ".join(self.releases)}</td>
336336
</tr>
337337
</tbody>
338338
</table>

malariagen_data/amin1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _repr_html_(self):
181181
<th style="text-align: left">
182182
Data releases available
183183
</th>
184-
<td>{', '.join(self.releases)}</td>
184+
<td>{", ".join(self.releases)}</td>
185185
</tr>
186186
<tr>
187187
<th style="text-align: left">
@@ -229,7 +229,7 @@ def _repr_html_(self):
229229
<th style="text-align: left">
230230
Relevant data releases
231231
</th>
232-
<td>{', '.join(self.releases)}</td>
232+
<td>{", ".join(self.releases)}</td>
233233
</tr>
234234
</tbody>
235235
</table>

malariagen_data/anoph/base.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23

34
import json
45
from contextlib import nullcontext
@@ -134,7 +135,7 @@ def __init__(
134135
storage_options = dict()
135136
try:
136137
self._fs, self._base_path = _init_filesystem(self._url, **storage_options)
137-
except Exception as exc: # pragma: no cover
138+
except (OSError, ImportError) as exc: # pragma: no cover
138139
raise IOError(
139140
"An error occurred establishing a connection to the storage system. Please see the nested exception for more details."
140141
) from exc
@@ -143,7 +144,7 @@ def __init__(
143144
try:
144145
with self.open_file(self._config_path) as f:
145146
self._config = json.load(f)
146-
except Exception as exc: # pragma: no cover
147+
except (OSError, json.JSONDecodeError) as exc: # pragma: no cover
147148
if (isinstance(exc, OSError) and "forbidden" in str(exc).lower()) or (
148149
getattr(exc, "status", None) == 403
149150
):
@@ -496,7 +497,20 @@ def client_location(self) -> str:
496497
return location
497498

498499
def _surveillance_flags(self, sample_sets: List[str]):
499-
raise NotImplementedError("Subclasses must implement `_surveillance_flags`.")
500+
"""Return surveillance flags for sample sets. Subclasses should override to
501+
load real data; this base implementation returns empty data and warns.
502+
"""
503+
warnings.warn(
504+
"Surveillance flags not implemented for this resource; returning empty data.",
505+
UserWarning,
506+
stacklevel=2,
507+
)
508+
return pd.DataFrame(
509+
{
510+
"sample_id": pd.Series(dtype="object"),
511+
"is_surveillance": pd.Series(dtype="boolean"),
512+
}
513+
)
500514

501515
def _release_has_unrestricted_data(self, *, release: str):
502516
"""Return `True` if the specified release has any unrestricted data. Otherwise return `False`."""

malariagen_data/anoph/heterozygosity.py

Lines changed: 119 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,108 @@ def _sample_count_het(
395395

396396
return sample_id, sample_set, windows, counts
397397

398+
def cohort_count_het(
399+
self,
400+
region: Region,
401+
df_cohort_samples: pd.DataFrame,
402+
sample_sets: Optional[base_params.sample_sets],
403+
window_size: het_params.window_size,
404+
site_mask: Optional[base_params.site_mask],
405+
chunks: base_params.chunks,
406+
inline_array: base_params.inline_array,
407+
):
408+
"""Compute windowed heterozygosity counts for multiple samples in a cohort.
409+
410+
This method efficiently computes heterozygosity for all samples by loading
411+
SNP data once and computing across all samples, rather than calling snp_calls()
412+
repeatedly for each sample. This vectorized approach provides substantial
413+
performance improvements for large cohorts.
414+
415+
Parameters
416+
----------
417+
region : Region
418+
Genome region to analyze.
419+
df_cohort_samples : pd.DataFrame
420+
Sample metadata dataframe with at least 'sample_id' column.
421+
sample_sets : str, optional
422+
Sample set identifier(s).
423+
window_size : int
424+
Size of sliding windows for heterozygosity computation.
425+
site_mask : str, optional
426+
Site mask to apply.
427+
chunks : str or int, dict
428+
Chunk size for dask arrays.
429+
inline_array : bool
430+
Whether to inline arrays.
431+
432+
Returns
433+
-------
434+
dict
435+
Mapping from sample_id to (windows, counts) tuple, where:
436+
- windows: array of shape (n_windows, 2) with [start, stop] positions
437+
- counts: array of shape (n_windows,) with heterozygous site counts per window
438+
"""
439+
debug = self._log.debug
440+
441+
# Extract sample IDs from cohort dataframe
442+
sample_ids = df_cohort_samples["sample_id"].values
443+
444+
debug("access SNPs for all cohort samples")
445+
# Load SNP data once for all samples in cohort
446+
ds_snps = self.snp_calls(
447+
region=region,
448+
sample_sets=sample_sets,
449+
site_mask=site_mask,
450+
chunks=chunks,
451+
inline_array=inline_array,
452+
)
453+
454+
# Subset to cohort samples to ensure correct indexing
455+
ds_snps = ds_snps.set_index(samples="sample_id").sel(samples=sample_ids)
456+
sample_id_to_idx = {sid: idx for idx, sid in enumerate(sample_ids)}
457+
458+
# SNP positions (same for all samples)
459+
pos = ds_snps["variant_position"].values
460+
461+
# guard against window_size exceeding available sites
462+
if pos.shape[0] < window_size:
463+
raise ValueError(
464+
f"Not enough sites ({pos.shape[0]}) for window size "
465+
f"({window_size}). Please reduce the window size or "
466+
f"use different site selection criteria."
467+
)
468+
469+
# Compute window coordinates once (same for all samples)
470+
windows = allel.moving_statistic(
471+
values=pos,
472+
statistic=lambda x: [x[0], x[-1]],
473+
size=window_size,
474+
)
475+
476+
# access genotypes for all samples
477+
gt_data = ds_snps["call_genotype"].data
478+
479+
# Compute windowed heterozygosity for each sample and cache results
480+
results = {}
481+
for sample_id, sample_idx in sample_id_to_idx.items():
482+
# Compute heterozygous genotypes for this sample only to avoid
483+
# materializing the full (variants, samples) array in memory.
484+
debug(f"Compute heterozygous genotypes for sample {sample_id}")
485+
gt_sample = allel.GenotypeDaskVector(gt_data[:, sample_idx, :])
486+
with self._dask_progress(desc="Compute heterozygous genotypes"):
487+
is_het_sample = gt_sample.is_het().compute()
488+
489+
# compute windowed heterozygosity for this sample
490+
counts = allel.moving_statistic(
491+
values=is_het_sample,
492+
statistic=np.sum,
493+
size=window_size,
494+
)
495+
496+
results[sample_id] = (windows, counts)
497+
498+
return results
499+
398500
@property
399501
def _roh_hmm_cache_name(self):
400502
return "roh_hmm_v1"
@@ -816,18 +918,25 @@ def cohort_heterozygosity(
816918
)
817919
n_samples = len(df_cohort_samples)
818920

819-
# Compute heterozygosity for each sample and take the mean.
921+
# Compute heterozygosity for all samples in the cohort using cohort_count_het().
922+
# This public method loads SNP data once and computes across all samples,
923+
# providing substantial speedup over sequential per-sample processing.
924+
cohort_het_results = self.cohort_count_het(
925+
region=region_prepped,
926+
df_cohort_samples=df_cohort_samples,
927+
sample_sets=sample_sets,
928+
window_size=window_size,
929+
site_mask=site_mask,
930+
chunks=chunks,
931+
inline_array=inline_array,
932+
)
933+
934+
# Compute per-sample means and aggregate.
820935
het_values = []
821936
for sample_id in df_cohort_samples["sample_id"]:
822-
df_het = self.sample_count_het(
823-
sample=sample_id,
824-
region=region_prepped,
825-
window_size=window_size,
826-
site_mask=site_mask,
827-
chunks=chunks,
828-
inline_array=inline_array,
829-
)
830-
het_values.append(df_het["heterozygosity"].mean())
937+
_, counts = cohort_het_results[sample_id]
938+
het_mean = np.mean(counts / window_size)
939+
het_values.append(het_mean)
831940

832941
results.append(
833942
{

malariagen_data/anoph/map_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_basemap_abbrevs() -> dict:
4242
for key, provider_fn in _basemap_abbrev_candidates.items():
4343
try:
4444
_basemap_abbrevs[key] = provider_fn()
45-
except Exception:
45+
except (ImportError, AttributeError):
4646
warnings.warn(
4747
f"Basemap provider {key!r} is not available and will be skipped.",
4848
stacklevel=2,

malariagen_data/anoph/sample_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1864,7 +1864,7 @@ def _locate_cohorts(*, cohorts, data, min_cohort_size):
18641864
for coh, query in cohorts.items():
18651865
try:
18661866
loc_coh = data.eval(query).values
1867-
except Exception as e:
1867+
except (KeyError, NameError, SyntaxError, TypeError, AttributeError) as e:
18681868
raise ValueError(
18691869
f"Invalid query for cohort {coh!r}: {query!r}. Error: {e}"
18701870
) from e

malariagen_data/anopheles.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from abc import abstractmethod
21
from typing import Any, Dict, Mapping, Optional, Tuple, Sequence
32

43
import allel # type: ignore
@@ -182,15 +181,47 @@ def __init__(
182181
surveillance_use_only=surveillance_use_only,
183182
)
184183

185-
@property
186-
@abstractmethod
187-
def _xpehh_gwss_cache_name(self):
188-
raise NotImplementedError("Must override _xpehh_gwss_cache_name")
184+
def _get_xpehh_gwss_cache_name(self):
185+
"""Safely resolve the xpehh gwss cache name.
189186
190-
@property
191-
@abstractmethod
192-
def _ihs_gwss_cache_name(self):
193-
raise NotImplementedError("Must override _ihs_gwss_cache_name")
187+
Supports class attribute, property, or legacy method override.
188+
Falls back to the default "xpehh_gwss_v1" if resolution fails.
189+
190+
See also: https://github.com/malariagen/malariagen-data-python/issues/1151
191+
"""
192+
try:
193+
name = self._xpehh_gwss_cache_name
194+
# Handle legacy case where _xpehh_gwss_cache_name might be a
195+
# callable method rather than a property or class attribute.
196+
if callable(name):
197+
name = name()
198+
if isinstance(name, str) and len(name) > 0:
199+
return name
200+
except NotImplementedError:
201+
pass
202+
# Fallback to default.
203+
return "xpehh_gwss_v1"
204+
205+
def _get_ihs_gwss_cache_name(self):
206+
"""Safely resolve the ihs gwss cache name.
207+
208+
Supports class attribute, property, or legacy method override.
209+
Falls back to the default "ihs_gwss_v1" if resolution fails.
210+
211+
See also: https://github.com/malariagen/malariagen-data-python/issues/1151
212+
"""
213+
try:
214+
name = self._ihs_gwss_cache_name
215+
# Handle legacy case where _ihs_gwss_cache_name might be a
216+
# callable method rather than a property or class attribute.
217+
if callable(name):
218+
name = name()
219+
if isinstance(name, str) and len(name) > 0:
220+
return name
221+
except NotImplementedError:
222+
pass
223+
# Fallback to default.
224+
return "ihs_gwss_v1"
194225

195226
@staticmethod
196227
def _make_gene_cnv_label(gene_id, gene_name, cnv_type):
@@ -727,7 +758,7 @@ def ihs_gwss(
727758
) -> Tuple[np.ndarray, np.ndarray]:
728759
# change this name if you ever change the behaviour of this function, to
729760
# invalidate any previously cached data
730-
name = self._ihs_gwss_cache_name
761+
name = self._get_ihs_gwss_cache_name()
731762

732763
params = dict(
733764
contig=contig,
@@ -1251,7 +1282,7 @@ def xpehh_gwss(
12511282
) -> Tuple[np.ndarray, np.ndarray]:
12521283
# change this name if you ever change the behaviour of this function, to
12531284
# invalidate any previously cached data
1254-
name = self._xpehh_gwss_cache_name
1285+
name = self._get_xpehh_gwss_cache_name()
12551286

12561287
params = dict(
12571288
contig=contig,

0 commit comments

Comments (0)