Skip to content

Commit 75ab467

Browse files
Improve plot_njt errors for insufficient data
Signed-off-by: Aryan-SINGH-GIT <aryansingh12oct2005@gmail.com>
1 parent bef737a commit 75ab467

3 files changed

Lines changed: 86 additions & 1 deletion

File tree

malariagen_data/anoph/distance.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,11 +477,24 @@ def plot_njt(
477477
# functions.
478478
import anjl # type: ignore
479479

480+
480481
# Normalise params.
481482
if count_sort is None and distance_sort is None:
482483
count_sort = True
483484
distance_sort = False
484485

486+
# Ensure we have enough samples for a tree.
487+
# If we have 0 samples, `biallelic_snp_calls` or `snp_calls` should have already raised "No samples found".
488+
# However, if we have 1 sample, it might pass through until here, where it would cause a failure in njt.
489+
df_samples = self.sample_metadata(
490+
sample_sets=sample_sets,
491+
sample_query=sample_query,
492+
sample_query_options=sample_query_options,
493+
sample_indices=sample_indices,
494+
)
495+
if 0 < len(df_samples) < 2:
496+
raise ValueError(f"Not enough samples for neighbour-joining tree. Found {len(df_samples)}, needed at least 2.")
497+
485498
# Compute neighbour-joining tree.
486499
Z, samples, n_snps_used = self.njt(
487500
region=region,

malariagen_data/anoph/snp_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1886,7 +1886,9 @@ def biallelic_snp_calls(
18861886
ds_out = ds_out.isel(variants=loc_thin)
18871887

18881888
elif ds_out.sizes["variants"] < n_snps:
1889-
raise ValueError("Not enough SNPs.")
1889+
raise ValueError(
1890+
f"Not enough SNPs. Found {ds_out.sizes['variants']}, needed {n_snps}."
1891+
)
18901892

18911893
return ds_out
18921894

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
2+
import pytest
3+
from malariagen_data import ag3 as _ag3
4+
from malariagen_data.anoph.distance import AnophelesDistanceAnalysis
5+
6+
@pytest.fixture
7+
def ag3_sim_api(ag3_sim_fixture):
8+
return AnophelesDistanceAnalysis(
9+
url=ag3_sim_fixture.url,
10+
public_url=ag3_sim_fixture.url,
11+
config_path=_ag3.CONFIG_PATH,
12+
major_version_number=_ag3.MAJOR_VERSION_NUMBER,
13+
major_version_path=_ag3.MAJOR_VERSION_PATH,
14+
pre=True,
15+
aim_metadata_dtype={
16+
"aim_species_fraction_arab": "float64",
17+
"aim_species_fraction_colu": "float64",
18+
"aim_species_fraction_colu_no2l": "float64",
19+
"aim_species_gambcolu_arabiensis": object,
20+
"aim_species_gambiae_coluzzii": object,
21+
"aim_species": object,
22+
},
23+
gff_gene_type="gene",
24+
gff_gene_name_attribute="Name",
25+
gff_default_attributes=("ID", "Parent", "Name", "description"),
26+
default_site_mask="gamb_colu_arab",
27+
results_cache=ag3_sim_fixture.results_cache_path.as_posix(),
28+
taxon_colors=_ag3.TAXON_COLORS,
29+
virtual_contigs=_ag3.VIRTUAL_CONTIGS,
30+
)
31+
32+
def test_plot_njt_no_samples(ag3_sim_api):
33+
# Test with a query matching no samples.
34+
with pytest.raises(ValueError) as e:
35+
ag3_sim_api.plot_njt(
36+
region="2L",
37+
n_snps=10,
38+
sample_query="sex_call == 'Impossible_Value'"
39+
)
40+
assert "No samples found for query" in str(e.value) or "No relevant samples found" in str(e.value)
41+
42+
def test_plot_njt_not_enough_snps(ag3_sim_api):
43+
# Request more SNPs than available in the region
44+
with pytest.raises(ValueError) as e:
45+
ag3_sim_api.plot_njt(
46+
region="2L",
47+
n_snps=10000000,
48+
sample_query=None
49+
)
50+
assert "Not enough SNPs." in str(e.value)
51+
assert "Found" in str(e.value)
52+
assert "needed 10000000" in str(e.value)
53+
54+
def test_plot_njt_one_sample(ag3_sim_api):
55+
# Test with a query that returns only 1 sample.
56+
# This should trigger the minimum sample check in plot_njt.
57+
58+
# First, find a sample so we can query for just one
59+
df_samples = ag3_sim_api.sample_metadata()
60+
sample_id = df_samples.iloc[0]['sample_id']
61+
62+
with pytest.raises(ValueError) as e:
63+
ag3_sim_api.plot_njt(
64+
region="2L",
65+
n_snps=10,
66+
sample_query=f"sample_id == '{sample_id}'"
67+
)
68+
assert "Not enough samples for neighbour-joining tree" in str(e.value)
69+
assert "Found 1" in str(e.value)
70+
assert "needed at least 2" in str(e.value)

0 commit comments

Comments
 (0)