Skip to content

Commit d9dc96e

Browse files
Improve plot_njt errors for insufficient data
Signed-off-by: Aryan-SINGH-GIT <aryansingh12oct2005@gmail.com>
1 parent bef737a commit d9dc96e

3 files changed

Lines changed: 84 additions & 1 deletion

File tree

malariagen_data/anoph/distance.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,20 @@ def plot_njt(
482482
count_sort = True
483483
distance_sort = False
484484

485+
# Ensure we have enough samples for a tree.
486+
# If we have 0 samples, `biallelic_snp_calls` or `snp_calls` should have already raised "No samples found".
487+
# However, if we have 1 sample, it might pass through until here, where it would cause a failure in njt.
488+
df_samples = self.sample_metadata(
489+
sample_sets=sample_sets,
490+
sample_query=sample_query,
491+
sample_query_options=sample_query_options,
492+
sample_indices=sample_indices,
493+
)
494+
if 0 < len(df_samples) < 2:
495+
raise ValueError(
496+
f"Not enough samples for neighbour-joining tree. Found {len(df_samples)}, needed at least 2."
497+
)
498+
485499
# Compute neighbour-joining tree.
486500
Z, samples, n_snps_used = self.njt(
487501
region=region,

malariagen_data/anoph/snp_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1886,7 +1886,9 @@ def biallelic_snp_calls(
18861886
ds_out = ds_out.isel(variants=loc_thin)
18871887

18881888
elif ds_out.sizes["variants"] < n_snps:
1889-
raise ValueError("Not enough SNPs.")
1889+
raise ValueError(
1890+
f"Not enough SNPs. Found {ds_out.sizes['variants']}, needed {n_snps}."
1891+
)
18901892

18911893
return ds_out
18921894

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import pytest
2+
from malariagen_data import ag3 as _ag3
3+
from malariagen_data.anoph.distance import AnophelesDistanceAnalysis
4+
5+
6+
@pytest.fixture
7+
def ag3_sim_api(ag3_sim_fixture):
8+
return AnophelesDistanceAnalysis(
9+
url=ag3_sim_fixture.url,
10+
public_url=ag3_sim_fixture.url,
11+
config_path=_ag3.CONFIG_PATH,
12+
major_version_number=_ag3.MAJOR_VERSION_NUMBER,
13+
major_version_path=_ag3.MAJOR_VERSION_PATH,
14+
pre=True,
15+
aim_metadata_dtype={
16+
"aim_species_fraction_arab": "float64",
17+
"aim_species_fraction_colu": "float64",
18+
"aim_species_fraction_colu_no2l": "float64",
19+
"aim_species_gambcolu_arabiensis": object,
20+
"aim_species_gambiae_coluzzii": object,
21+
"aim_species": object,
22+
},
23+
gff_gene_type="gene",
24+
gff_gene_name_attribute="Name",
25+
gff_default_attributes=("ID", "Parent", "Name", "description"),
26+
default_site_mask="gamb_colu_arab",
27+
results_cache=ag3_sim_fixture.results_cache_path.as_posix(),
28+
taxon_colors=_ag3.TAXON_COLORS,
29+
virtual_contigs=_ag3.VIRTUAL_CONTIGS,
30+
)
31+
32+
33+
def test_plot_njt_no_samples(ag3_sim_api):
34+
# Test with a query matching no samples.
35+
with pytest.raises(ValueError) as e:
36+
ag3_sim_api.plot_njt(
37+
region="2L", n_snps=10, sample_query="sex_call == 'Impossible_Value'"
38+
)
39+
assert "No samples found for query" in str(
40+
e.value
41+
) or "No relevant samples found" in str(e.value)
42+
43+
44+
def test_plot_njt_not_enough_snps(ag3_sim_api):
45+
# Request more SNPs than available in the region
46+
with pytest.raises(ValueError) as e:
47+
ag3_sim_api.plot_njt(region="2L", n_snps=10000000, sample_query=None)
48+
assert "Not enough SNPs." in str(e.value)
49+
assert "Found" in str(e.value)
50+
assert "needed 10000000" in str(e.value)
51+
52+
53+
def test_plot_njt_one_sample(ag3_sim_api):
54+
# Test with a query that returns only 1 sample.
55+
# This should trigger the minimum sample check in plot_njt.
56+
57+
# First, find a sample so we can query for just one
58+
df_samples = ag3_sim_api.sample_metadata()
59+
sample_id = df_samples.iloc[0]["sample_id"]
60+
61+
with pytest.raises(ValueError) as e:
62+
ag3_sim_api.plot_njt(
63+
region="2L", n_snps=10, sample_query=f"sample_id == '{sample_id}'"
64+
)
65+
assert "Not enough samples for neighbour-joining tree" in str(e.value)
66+
assert "Found 1" in str(e.value)
67+
assert "needed at least 2" in str(e.value)

0 commit comments

Comments
 (0)