|
| 1 | +import pytest |
| 2 | +from malariagen_data import ag3 as _ag3 |
| 3 | +from malariagen_data.anoph.distance import AnophelesDistanceAnalysis |
| 4 | + |
| 5 | + |
@pytest.fixture
def ag3_sim_api(ag3_sim_fixture):
    """Return an ``AnophelesDistanceAnalysis`` wired to the simulated Ag3 data.

    Both the private and public URLs point at ``ag3_sim_fixture.url``, and the
    results cache is placed at ``ag3_sim_fixture.results_cache_path``. The
    remaining settings are taken from the ``ag3`` module constants or are
    spelled out literally below.

    NOTE(review): ``ag3_sim_fixture`` is presumably provided by the test
    suite's conftest -- confirm.
    """
    # Column dtypes handed to the API for the AIM metadata columns.
    aim_dtypes = {
        "aim_species_fraction_arab": "float64",
        "aim_species_fraction_colu": "float64",
        "aim_species_fraction_colu_no2l": "float64",
        "aim_species_gambcolu_arabiensis": object,
        "aim_species_gambiae_coluzzii": object,
        "aim_species": object,
    }
    sim_url = ag3_sim_fixture.url
    cache_dir = ag3_sim_fixture.results_cache_path.as_posix()

    api_kwargs = dict(
        url=sim_url,
        public_url=sim_url,
        config_path=_ag3.CONFIG_PATH,
        major_version_number=_ag3.MAJOR_VERSION_NUMBER,
        major_version_path=_ag3.MAJOR_VERSION_PATH,
        pre=True,
        aim_metadata_dtype=aim_dtypes,
        gff_gene_type="gene",
        gff_gene_name_attribute="Name",
        gff_default_attributes=("ID", "Parent", "Name", "description"),
        default_site_mask="gamb_colu_arab",
        results_cache=cache_dir,
        taxon_colors=_ag3.TAXON_COLORS,
        virtual_contigs=_ag3.VIRTUAL_CONTIGS,
    )
    return AnophelesDistanceAnalysis(**api_kwargs)
| 31 | + |
| 32 | + |
def test_plot_njt_no_samples(ag3_sim_api):
    """``plot_njt`` raises ``ValueError`` for a query that matches no samples."""
    # Query intended to select nothing.
    empty_query = "sex_call == 'Impossible_Value'"

    with pytest.raises(ValueError) as exc_info:
        ag3_sim_api.plot_njt(region="2L", n_snps=10, sample_query=empty_query)

    # Checked after the ``with`` block so the assertion runs once the
    # exception has been captured. Either wording of the message is accepted.
    message = str(exc_info.value)
    accepted_fragments = (
        "No samples found for query",
        "No relevant samples found",
    )
    assert any(fragment in message for fragment in accepted_fragments)
| 42 | + |
| 43 | + |
def test_plot_njt_not_enough_snps(ag3_sim_api):
    """``plot_njt`` raises ``ValueError`` when asked for too many SNPs."""
    # Ask for more SNPs than the region is expected to contain.
    with pytest.raises(ValueError) as exc_info:
        ag3_sim_api.plot_njt(region="2L", n_snps=10000000, sample_query=None)

    # Checked after the ``with`` block so the assertions run once the
    # exception has been captured. Every fragment must appear in the message.
    message = str(exc_info.value)
    for fragment in ("Not enough SNPs.", "Found", "needed 10000000"):
        assert fragment in message
| 51 | + |
| 52 | + |
def test_plot_njt_one_sample(ag3_sim_api):
    """``plot_njt`` raises ``ValueError`` when the query selects one sample.

    This should trigger the minimum sample check in ``plot_njt``.
    """
    # Take the first sample in the metadata and build a query matching only it.
    first_row = ag3_sim_api.sample_metadata().iloc[0]
    sample_id = first_row["sample_id"]
    single_sample_query = f"sample_id == '{sample_id}'"

    with pytest.raises(ValueError) as exc_info:
        ag3_sim_api.plot_njt(
            region="2L", n_snps=10, sample_query=single_sample_query
        )

    # Checked after the ``with`` block so the assertions run once the
    # exception has been captured.
    message = str(exc_info.value)
    expected_fragments = (
        "Not enough samples for neighbour-joining tree",
        "Found 1",
        "needed at least 2",
    )
    for fragment in expected_fragments:
        assert fragment in message
0 commit comments