Skip to content

Commit c768147

Browse files
Merge branch 'master' into optimize/cohort-heterozygosity-vectorized
2 parents 5e5fc44 + 22dcc29 commit c768147

8 files changed

Lines changed: 730 additions & 53 deletions

File tree

malariagen_data/anoph/heterozygosity.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,27 @@ def cohort_count_het(
501501
def _roh_hmm_cache_name(self):
502502
return "roh_hmm_v1"
503503

504+
def _get_roh_hmm_cache_name(self):
    """Safely resolve the ROH HMM cache name.

    Supports class attribute, property, or legacy method override.
    Falls back to the default "roh_hmm_v1" if resolution fails, i.e. when
    the attribute is missing, when reading or calling it raises
    ``NotImplementedError``, or when it does not resolve to a non-empty
    string.

    See also: https://github.com/malariagen/malariagen-data-python/issues/1151

    Returns
    -------
    str
        The resolved cache name; never empty.
    """
    default = "roh_hmm_v1"

    # Keep the try body limited to the two operations that can raise.
    try:
        name = self._roh_hmm_cache_name
        # Handle legacy case where _roh_hmm_cache_name might be a
        # callable method rather than a property or class attribute.
        if callable(name):
            name = name()
    except (AttributeError, NotImplementedError):
        # Attribute absent, or an abstract placeholder never overridden.
        # NOTE(review): this also hides an AttributeError raised from
        # inside a property implementation - confirm that is acceptable.
        return default

    # Anything other than a non-empty string is treated as unresolved.
    if isinstance(name, str) and name:
        return name
    return default
524+
504525
@_check_types
505526
@doc(
506527
summary="Infer runs of homozygosity for a single sample over a genome region.",
@@ -522,7 +543,7 @@ def roh_hmm(
522543

523544
resolved_region: Region = _parse_single_region(self, region)
524545

525-
name = self._roh_hmm_cache_name
546+
name = self._get_roh_hmm_cache_name()
526547

527548
params = dict(
528549
sample=sample,

malariagen_data/anoph/pca.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -399,9 +399,9 @@ def plot_pca_coords(
399399

400400
# Apply jitter if desired - helps spread out points when tightly clustered.
401401
if jitter_frac:
402-
np.random.seed(random_seed)
403-
data[x] = _jitter(data[x], jitter_frac)
404-
data[y] = _jitter(data[y], jitter_frac)
402+
rng = np.random.default_rng(seed=random_seed)
403+
data[x] = _jitter(data[x], jitter_frac, random_state=rng)
404+
data[y] = _jitter(data[y], jitter_frac, random_state=rng)
405405

406406
# Convenience variables.
407407
# Prevent lint error (mypy): Unsupported operand types for + ("Series[Any]" and "str")
@@ -503,10 +503,10 @@ def plot_pca_coords_3d(
503503

504504
# Apply jitter if desired - helps spread out points when tightly clustered.
505505
if jitter_frac:
506-
np.random.seed(random_seed)
507-
data[x] = _jitter(data[x], jitter_frac)
508-
data[y] = _jitter(data[y], jitter_frac)
509-
data[z] = _jitter(data[z], jitter_frac)
506+
rng = np.random.default_rng(seed=random_seed)
507+
data[x] = _jitter(data[x], jitter_frac, random_state=rng)
508+
data[y] = _jitter(data[y], jitter_frac, random_state=rng)
509+
data[z] = _jitter(data[z], jitter_frac, random_state=rng)
510510

511511
# Convenience variables.
512512
# Prevent lint error (mypy): Unsupported operand types for + ("Series[Any]" and "str")

malariagen_data/util.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -897,10 +897,33 @@ def _hash_params(params):
897897
return h, s
898898

899899

900-
def _jitter(a, fraction):
901-
"""Jitter data in `a` using the fraction `f`."""
900+
def _jitter(a, fraction, random_state=np.random):
    """Jitter data by adding uniform noise scaled by the data range.

    Each element receives independent noise equal to ``fraction`` times a
    draw from ``uniform(-r, r)``, where ``r = a.max() - a.min()``.

    Parameters
    ----------
    a : array-like
        Input data to jitter. Can be a numpy array or pandas Series; it must
        provide ``max()``, ``min()``, ``shape`` and ``size``.
    fraction : float
        Controls the amplitude of the jitter relative to the data range.
    random_state : numpy.random.Generator or module, optional
        Random number generator to use. Accepts a ``numpy.random.Generator``
        (from ``np.random.default_rng()``) or the ``numpy.random`` module.
        Defaults to ``np.random`` (global RNG) for backward compatibility.

    Returns
    -------
    array-like
        Jittered copy of the input data with the same shape. The input is
        not modified. Empty input yields an empty result.

    Notes
    -----
    Prefer passing a local ``np.random.default_rng(seed=...)`` to avoid
    mutating global RNG state and to ensure reproducibility.

    """
    if a.size == 0:
        # max()/min() raise ValueError on a zero-size numpy array. Adding
        # zeros keeps the same type promotion as the normal path below.
        return a + fraction * np.zeros(a.shape)
    r = a.max() - a.min()
    return a + fraction * random_state.uniform(-r, r, a.shape)
904927

905928

906929
class CacheMiss(Exception):

tests/anoph/test_fst.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,16 @@ def test_pairwise_average_fst_with_sample_query(fixture, api: AnophelesFstAnalys
346346
n_jack=random.randint(10, 200),
347347
)
348348

349-
# Run checks.
350-
check_pairwise_average_fst(api=api, fst_params=fst_params)
349+
# Run checks - skip if random parameter selection results in insufficient cohorts.
350+
try:
351+
check_pairwise_average_fst(api=api, fst_params=fst_params)
352+
except ValueError as e:
353+
if "No cohorts remain" in str(e):
354+
pytest.skip(
355+
f"Skipping: random parameter selection produced insufficient "
356+
f"cohorts for taxon={taxon!r}: {e}"
357+
)
358+
raise
351359

352360

353361
@parametrize_with_cases("fixture,api", cases=".")

tests/anoph/test_hap_frq.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,12 @@ def test_hap_frequencies_with_str_cohorts(
187187
return
188188

189189
# Run the function under test.
190-
df_hap = api.haplotypes_frequencies(**params)
190+
try:
191+
df_hap = api.haplotypes_frequencies(**params)
192+
except ValueError as e:
193+
if "No SNPs available for the given region" in str(e):
194+
pytest.skip("Random region contained no SNPs")
195+
raise
191196

192197
check_plot_frequencies_heatmap(api, df_hap)
193198

tests/anoph/test_heterozygosity.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ def test_roh_hmm(fixture, api: AnophelesHetAnalysis):
186186
assert col in df_roh.columns
187187

188188

189+
@parametrize_with_cases("fixture,api", cases=".")
def test_roh_hmm_cache_name_resolution(fixture, api: AnophelesHetAnalysis):
    """Regression test for GH#1151: the ROH HMM cache name resolves to a string.

    Calls the resolver on the API under test and checks that it returns a
    non-empty ``str`` rather than raising.
    """
    name = api._get_roh_hmm_cache_name()

    # Check the type first so the emptiness check only ever sees a str.
    assert isinstance(name, str), f"Expected str, got {type(name)}"
    assert name != "", "Cache name must be non-empty"
200+
201+
189202
@parametrize_with_cases("fixture,api", cases=".")
190203
def test_plot_roh(fixture, api: AnophelesHetAnalysis):
191204
# Set up test parameters.

tests/anoph/test_pca.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,52 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca):
313313
len(pca_df.query(f"sample_id in {exclude_samples} and not pca_fit"))
314314
== n_samples_excluded
315315
)
316+
317+
318+
# --- _jitter() determinism unit tests ---
319+
320+
321+
def test_jitter_determinism():
    """Two generators built from the same seed must yield identical jitter."""
    from malariagen_data.util import _jitter

    values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    seed = 42

    # Build a fresh generator per call so both start from the same state.
    first, second = (
        _jitter(values, 0.1, random_state=np.random.default_rng(seed=seed))
        for _ in range(2)
    )

    np.testing.assert_array_equal(first, second)
335+
336+
337+
def test_jitter_different_seeds():
    """Generators built from different seeds must yield different jitter."""
    from malariagen_data.util import _jitter

    values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

    # One independent generator per seed.
    outputs = [
        _jitter(values, 0.1, random_state=np.random.default_rng(seed=seed))
        for seed in (42, 99)
    ]

    assert not np.array_equal(outputs[0], outputs[1])
351+
352+
353+
def test_jitter_no_global_rng_side_effect():
    """_jitter with explicit random_state must not alter global RNG state.

    The global state is only read, never reseeded, so this test does not
    itself disturb other tests that rely on ``np.random``. All fields of the
    legacy state tuple are compared, not just the key array, so a draw that
    only advances the position counter is detected too.
    """
    from malariagen_data.util import _jitter

    # Legacy state tuple: (name, key array, pos, has_gauss, cached_gaussian).
    before = np.random.get_state()
    key_before = before[1].copy()

    rng = np.random.default_rng(seed=42)
    _jitter(np.array([1.0, 2.0, 3.0]), 0.1, random_state=rng)

    after = np.random.get_state()
    assert before[0] == after[0]
    np.testing.assert_array_equal(key_before, after[1])
    assert before[2:] == after[2:]

0 commit comments

Comments
 (0)