diff --git a/malariagen_data/anoph/pbs.py b/malariagen_data/anoph/pbs.py new file mode 100644 index 000000000..d0059b60b --- /dev/null +++ b/malariagen_data/anoph/pbs.py @@ -0,0 +1,417 @@ +import warnings +from typing import Tuple, Optional + +import numpy as np +import allel # type: ignore +from numpydoc_decorator import doc # type: ignore +import bokeh.models +import bokeh.plotting +import bokeh.layouts + +from .snp_data import AnophelesSnpData +from . import base_params, pbs_params, gplt_params +from ..util import CacheMiss, _check_types + + +class AnophelesPbsAnalysis( + AnophelesSnpData, +): + def __init__( + self, + **kwargs, + ): + # N.B., this class is designed to work cooperatively, and + # so it's important that any remaining parameters are passed + # to the superclass constructor. + super().__init__(**kwargs) + + def _pbs_gwss( + self, + *, + contig, + window_size, + sample_sets, + cohort1_query, + cohort2_query, + cohort3_query, + sample_query_options, + site_mask, + cohort_size, + min_cohort_size, + max_cohort_size, + normed, + random_seed, + inline_array, + chunks, + min_snps_threshold, + window_adjustment_factor, + ): + # Compute allele counts for cohort 1 (focal population). + ac1 = self.snp_allele_counts( + region=contig, + sample_query=cohort1_query, + sample_query_options=sample_query_options, + sample_sets=sample_sets, + site_mask=site_mask, + cohort_size=cohort_size, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + inline_array=inline_array, + chunks=chunks, + ) + # Compute allele counts for cohort 2. + ac2 = self.snp_allele_counts( + region=contig, + sample_query=cohort2_query, + sample_query_options=sample_query_options, + sample_sets=sample_sets, + site_mask=site_mask, + cohort_size=cohort_size, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + inline_array=inline_array, + chunks=chunks, + ) + # Compute allele counts for cohort 3 (outgroup). 
+ ac3 = self.snp_allele_counts( + region=contig, + sample_query=cohort3_query, + sample_query_options=sample_query_options, + sample_sets=sample_sets, + site_mask=site_mask, + cohort_size=cohort_size, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + inline_array=inline_array, + chunks=chunks, + ) + + with self._spinner(desc="Load SNP positions"): + pos = self.snp_sites( + region=contig, + field="POS", + site_mask=site_mask, + inline_array=inline_array, + chunks=chunks, + ).compute() + + n_snps = len(pos) + if n_snps < min_snps_threshold: + raise ValueError( + f"Too few SNP sites ({n_snps}) available for PBS GWSS. " + f"At least {min_snps_threshold} sites are required. " + "Try a larger genomic region or different site selection criteria." + ) + if window_size >= n_snps: + adjusted_window_size = max(1, n_snps // window_adjustment_factor) + warnings.warn( + f"window_size ({window_size}) is >= the number of SNP sites " + f"available ({n_snps}); automatically adjusting window_size to " + f"{adjusted_window_size} (= {n_snps} // {window_adjustment_factor}).", + UserWarning, + stacklevel=2, + ) + window_size = adjusted_window_size + + with self._spinner(desc="Compute PBS"): + with np.errstate(divide="ignore", invalid="ignore"): + pbs = allel.pbs( + ac1=ac1, + ac2=ac2, + ac3=ac3, + window_size=window_size, + normed=normed, + ) + x = allel.moving_statistic(pos, statistic=np.mean, size=window_size) + + results = dict(x=x, pbs=pbs) + + return results + + @_check_types + @doc( + summary=""" + Run a PBS genome-wide scan to detect lineage-specific selection + in a focal population relative to two other populations. + Uses the Population Branch Statistic (Yi et al. 2010). + If window_size is >= the number of available SNP sites, a + UserWarning is issued and window_size is automatically adjusted. + A ValueError is raised if the number of available SNP sites is + below min_snps_threshold. 
+ """, + parameters=dict( + min_snps_threshold=""" + Minimum number of SNP sites required. If fewer sites are + available a ValueError is raised. + """, + window_adjustment_factor=""" + If window_size is >= the number of available SNP sites, + window_size is automatically set to + number_of_snps // window_adjustment_factor. + """, + ), + returns=dict( + x="An array containing the window centre point genomic positions.", + pbs="An array with PBS statistic values for each window.", + ), + ) + def pbs_gwss( + self, + contig: base_params.contig, + window_size: pbs_params.window_size, + cohort1_query: base_params.sample_query, + cohort2_query: base_params.sample_query, + cohort3_query: base_params.sample_query, + normed: pbs_params.normed = pbs_params.normed_default, + sample_query_options: Optional[base_params.sample_query_options] = None, + sample_sets: Optional[base_params.sample_sets] = None, + site_mask: Optional[base_params.site_mask] = base_params.DEFAULT, + cohort_size: Optional[base_params.cohort_size] = pbs_params.cohort_size_default, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = pbs_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = pbs_params.max_cohort_size_default, + random_seed: base_params.random_seed = 42, + inline_array: base_params.inline_array = base_params.inline_array_default, + chunks: base_params.chunks = base_params.native_chunks, + min_snps_threshold: pbs_params.min_snps_threshold = 1000, + window_adjustment_factor: pbs_params.window_adjustment_factor = 10, + ) -> Tuple[np.ndarray, np.ndarray]: + # Change this name if you ever change the behaviour of this function, to + # invalidate any previously cached data. 
+ name = "pbs_gwss_v1" + # N.B., every argument that can change the outcome must be in the cache key. + params = dict( + contig=contig, + window_size=window_size, + cohort1_query=self._prep_sample_query_param(sample_query=cohort1_query), + cohort2_query=self._prep_sample_query_param(sample_query=cohort2_query), + cohort3_query=self._prep_sample_query_param(sample_query=cohort3_query), + normed=normed, + sample_query_options=sample_query_options, + sample_sets=self._prep_sample_sets_param(sample_sets=sample_sets), + site_mask=self._prep_optional_site_mask_param(site_mask=site_mask), + cohort_size=cohort_size, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + min_snps_threshold=min_snps_threshold, + window_adjustment_factor=window_adjustment_factor, + ) + + try: + results = self.results_cache_get(name=name, params=params) + + except CacheMiss: + results = self._pbs_gwss( + **params, + inline_array=inline_array, + chunks=chunks, + ) + self.results_cache_set(name=name, params=params, results=results) + + x = results["x"] + pbs = results["pbs"] + + return x, pbs + + @_check_types + @doc( + summary=""" + Run and plot a PBS genome-wide scan to detect lineage-specific + selection in a focal population. 
+ """, + ) + def plot_pbs_gwss_track( + self, + contig: base_params.contig, + window_size: pbs_params.window_size, + cohort1_query: base_params.sample_query, + cohort2_query: base_params.sample_query, + cohort3_query: base_params.sample_query, + normed: pbs_params.normed = pbs_params.normed_default, + sample_query_options: Optional[base_params.sample_query_options] = None, + sample_sets: Optional[base_params.sample_sets] = None, + site_mask: Optional[base_params.site_mask] = base_params.DEFAULT, + cohort_size: Optional[base_params.cohort_size] = pbs_params.cohort_size_default, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = pbs_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = pbs_params.max_cohort_size_default, + random_seed: base_params.random_seed = 42, + title: Optional[gplt_params.title] = None, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + height: gplt_params.height = 200, + show: gplt_params.show = True, + x_range: Optional[gplt_params.x_range] = None, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + ) -> gplt_params.optional_figure: + # Compute PBS. + x, pbs = self.pbs_gwss( + contig=contig, + window_size=window_size, + cohort_size=cohort_size, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + cohort1_query=cohort1_query, + cohort2_query=cohort2_query, + cohort3_query=cohort3_query, + normed=normed, + sample_query_options=sample_query_options, + sample_sets=sample_sets, + site_mask=site_mask, + random_seed=random_seed, + ) + + # Determine X axis range. + x_min = x[0] + x_max = x[-1] + if x_range is None: + x_range = bokeh.models.Range1d(x_min, x_max, bounds="auto") + + # Create a figure. 
+ xwheel_zoom = bokeh.models.WheelZoomTool( + dimensions="width", maintain_focus=False + ) + if title is None: + title = ( + f"Focal: {cohort1_query}\n" + f"Comparison: {cohort2_query}\n" + f"Outgroup: {cohort3_query}" + ) + fig = bokeh.plotting.figure( + title=title, + tools=[ + "xpan", + "xzoom_in", + "xzoom_out", + xwheel_zoom, + "reset", + "save", + "crosshair", + ], + active_inspect=None, + active_scroll=xwheel_zoom, + active_drag="xpan", + sizing_mode=sizing_mode, + width=width, + height=height, + toolbar_location="above", + x_range=x_range, + output_backend=output_backend, + ) + + # Plot PBS. + fig.scatter( + x=x, + y=pbs, + size=3, + marker="circle", + line_width=1, + line_color="black", + fill_color=None, + ) + + # Tidy up the plot. + fig.yaxis.axis_label = "PBS" + self._bokeh_style_genome_xaxis(fig, contig) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return fig + + @_check_types + @doc( + summary=""" + Run and plot a PBS genome-wide scan with gene track to detect + lineage-specific selection in a focal population. 
+ """, + ) + def plot_pbs_gwss( + self, + contig: base_params.contig, + window_size: pbs_params.window_size, + cohort1_query: base_params.sample_query, + cohort2_query: base_params.sample_query, + cohort3_query: base_params.sample_query, + normed: pbs_params.normed = pbs_params.normed_default, + sample_query_options: Optional[base_params.sample_query_options] = None, + sample_sets: Optional[base_params.sample_sets] = None, + site_mask: Optional[base_params.site_mask] = base_params.DEFAULT, + cohort_size: Optional[base_params.cohort_size] = pbs_params.cohort_size_default, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = pbs_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = pbs_params.max_cohort_size_default, + random_seed: base_params.random_seed = 42, + title: Optional[gplt_params.title] = None, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + track_height: gplt_params.track_height = 190, + genes_height: gplt_params.genes_height = gplt_params.genes_height_default, + show: gplt_params.show = True, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + gene_labels: Optional[gplt_params.gene_labels] = None, + gene_labelset: Optional[gplt_params.gene_labelset] = None, + ) -> gplt_params.optional_figure: + # GWSS track. + fig1 = self.plot_pbs_gwss_track( + contig=contig, + window_size=window_size, + cohort1_query=cohort1_query, + cohort2_query=cohort2_query, + cohort3_query=cohort3_query, + normed=normed, + sample_query_options=sample_query_options, + sample_sets=sample_sets, + site_mask=site_mask, + cohort_size=cohort_size, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + title=title, + sizing_mode=sizing_mode, + width=width, + height=track_height, + show=False, + output_backend=output_backend, + ) + + fig1.xaxis.visible = False + + # Plot genes. 
+ fig2 = self.plot_genes( + region=contig, + sizing_mode=sizing_mode, + width=width, + height=genes_height, + x_range=fig1.x_range, + show=False, + output_backend=output_backend, + gene_labels=gene_labels, + gene_labelset=gene_labelset, + ) + + # Combine plots into a single figure. + fig = bokeh.layouts.gridplot( + [fig1, fig2], + ncols=1, + toolbar_location="above", + merge_tools=True, + sizing_mode=sizing_mode, + toolbar_options=dict(active_inspect=None), + ) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return fig diff --git a/malariagen_data/anoph/pbs_params.py b/malariagen_data/anoph/pbs_params.py new file mode 100644 index 000000000..7212eff78 --- /dev/null +++ b/malariagen_data/anoph/pbs_params.py @@ -0,0 +1,42 @@ +"""Parameter definitions for PBS functions.""" + +from typing import Optional + +from typing_extensions import Annotated, TypeAlias + +from . import base_params + +# N.B., window size can mean different things for different functions +window_size: TypeAlias = Annotated[ + int, + "The size of windows (number of sites) used to calculate statistics within.", +] + +cohort_size_default: Optional[base_params.cohort_size] = None +min_cohort_size_default: base_params.min_cohort_size = 15 +max_cohort_size_default: base_params.max_cohort_size = 50 + +normed: TypeAlias = Annotated[ + bool, + """ + If True, normalise the PBS values by the sum of the divergence times. + This can help to identify extreme outlier loci. Default is True. + """, +] +normed_default: bool = True + +min_snps_threshold: TypeAlias = Annotated[ + int, + """ + Minimum number of SNP sites required for the PBS GWSS computation. If + fewer sites are available, a ValueError is raised. + """, +] + +window_adjustment_factor: TypeAlias = Annotated[ + int, + """ + If window_size is >= the number of available SNP sites, the window_size + is automatically adjusted to number_of_snps // window_adjustment_factor. 
+ """, +] diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index 8342dbb88..2a2b600e5 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -39,6 +39,7 @@ from .anoph.to_vcf import SnpVcfExporter from .anoph.g123 import AnophelesG123Analysis from .anoph.fst import AnophelesFstAnalysis +from .anoph.pbs import AnophelesPbsAnalysis from .anoph.h12 import AnophelesH12Analysis from .anoph.h1x import AnophelesH1XAnalysis from .anoph.phenotypes import AnophelesPhenotypeData @@ -86,6 +87,7 @@ class AnophelesDataResource( AnophelesH12Analysis, AnophelesG123Analysis, AnophelesFstAnalysis, + AnophelesPbsAnalysis, AnophelesHetAnalysis, AnophelesHapFrequencyAnalysis, AnophelesDistanceAnalysis, diff --git a/notebooks/plot_pbs_gwss.ipynb b/notebooks/plot_pbs_gwss.ipynb new file mode 100644 index 000000000..c579cd65c --- /dev/null +++ b/notebooks/plot_pbs_gwss.ipynb @@ -0,0 +1,154 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "a31720f1", + "metadata": {}, + "outputs": [], + "source": [ + "import malariagen_data\n", + "\n", + "ag3 = malariagen_data.Ag3(\n", + " \"simplecache::gs://vo_agam_release_master_us_central1\",\n", + " simplecache=dict(cache_storage=\"../gcs_cache\"),\n", + " cohorts_analysis=\"20230516\",\n", + " results_cache=\"results_cache\",\n", + ")\n", + "ag3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ece45863", + "metadata": {}, + "outputs": [], + "source": [ + "import malariagen_data\n", + "\n", + "af1 = malariagen_data.Af1(\n", + " \"simplecache::gs://vo_afun_release_master_us_central1\",\n", + " simplecache=dict(cache_storage=\"../gcs_cache\"),\n", + " cohorts_analysis=\"20230823\",\n", + " results_cache=\"results_cache\",\n", + ")\n", + "af1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30c285b1", + "metadata": {}, + "outputs": [], + "source": [ + "!rm -rf results_cache" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "effd3060", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_pbs_gwss_track(\n", + " contig=\"2L\",\n", + " window_size=10_000,\n", + " cohort1_query=\"cohort_admin2_year == 'ML-2_Kati_colu_2014'\",\n", + " cohort2_query=\"cohort_admin2_year == 'ML-2_Kati_gamb_2014'\",\n", + " cohort3_query=\"taxon == 'arabiensis'\",\n", + " site_mask=\"gamb_colu_arab\",\n", + " cohort_size=20,\n", + " sample_sets=\"3.0\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b63d02fe", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_pbs_gwss(\n", + " contig=\"2L\",\n", + " window_size=10_000,\n", + " cohort1_query=\"cohort_admin2_year == 'ML-2_Kati_colu_2014'\",\n", + " cohort2_query=\"cohort_admin2_year == 'ML-2_Kati_gamb_2014'\",\n", + " cohort3_query=\"taxon == 'arabiensis'\",\n", + " site_mask=\"gamb_colu_arab\",\n", + " cohort_size=20,\n", + " sample_sets=\"3.0\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f77bbae1", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_pbs_gwss(\n", + " contig=\"2L\",\n", + " window_size=10_000,\n", + " cohort1_query=\"cohort_admin2_year == 'ML-2_Kati_colu_2014'\",\n", + " cohort2_query=\"cohort_admin2_year == 'ML-2_Kati_gamb_2014'\",\n", + " cohort3_query=\"taxon == 'arabiensis'\",\n", + " normed=False,\n", + " site_mask=\"gamb_colu_arab\",\n", + " cohort_size=20,\n", + " sample_sets=\"3.0\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b87b2f2", + "metadata": {}, + "outputs": [], + "source": [ + "af1.plot_pbs_gwss(\n", + " contig=\"X\",\n", + " window_size=20_000,\n", + " cohort1_query=\"cohort_admin1_year == 'KE-03_fune_2016'\",\n", + " cohort2_query=\"cohort_admin1_year == 'MZ-L_fune_2016'\",\n", + " cohort3_query=\"country == 'Ghana'\",\n", + " cohort_size=20,\n", + " sample_sets=\"1.0\",\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 
3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/anoph/test_pbs.py b/tests/anoph/test_pbs.py new file mode 100644 index 000000000..186806b23 --- /dev/null +++ b/tests/anoph/test_pbs.py @@ -0,0 +1,227 @@ +import pytest +from pytest_cases import parametrize_with_cases +import numpy as np +import bokeh.models + +from malariagen_data import af1 as _af1 +from malariagen_data import ag3 as _ag3 +from malariagen_data import adir1 as _adir1 + +from malariagen_data.anoph.pbs import AnophelesPbsAnalysis + + +@pytest.fixture +def ag3_sim_api(ag3_sim_fixture): + return AnophelesPbsAnalysis( + url=ag3_sim_fixture.url, + public_url=ag3_sim_fixture.url, + config_path=_ag3.CONFIG_PATH, + major_version_number=_ag3.MAJOR_VERSION_NUMBER, + major_version_path=_ag3.MAJOR_VERSION_PATH, + pre=True, + aim_metadata_dtype={ + "aim_species_fraction_arab": "float64", + "aim_species_fraction_colu": "float64", + "aim_species_fraction_colu_no2l": "float64", + "aim_species_gambcolu_arabiensis": object, + "aim_species_gambiae_coluzzii": object, + "aim_species": object, + }, + gff_gene_type="gene", + gff_gene_name_attribute="Name", + gff_default_attributes=("ID", "Parent", "Name", "description"), + default_site_mask="gamb_colu_arab", + results_cache=ag3_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_ag3.TAXON_COLORS, + virtual_contigs=_ag3.VIRTUAL_CONTIGS, + ) + + +@pytest.fixture +def af1_sim_api(af1_sim_fixture): + return AnophelesPbsAnalysis( + url=af1_sim_fixture.url, + public_url=af1_sim_fixture.url, + 
config_path=_af1.CONFIG_PATH, + major_version_number=_af1.MAJOR_VERSION_NUMBER, + major_version_path=_af1.MAJOR_VERSION_PATH, + pre=False, + gff_gene_type="protein_coding_gene", + gff_gene_name_attribute="Note", + gff_default_attributes=("ID", "Parent", "Note", "description"), + default_site_mask="funestus", + results_cache=af1_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_af1.TAXON_COLORS, + ) + + +@pytest.fixture +def adir1_sim_api(adir1_sim_fixture): + return AnophelesPbsAnalysis( + url=adir1_sim_fixture.url, + public_url=adir1_sim_fixture.url, + config_path=_adir1.CONFIG_PATH, + major_version_number=_adir1.MAJOR_VERSION_NUMBER, + major_version_path=_adir1.MAJOR_VERSION_PATH, + pre=False, + gff_gene_type="protein_coding_gene", + gff_gene_name_attribute="Note", + gff_default_attributes=("ID", "Parent", "Note", "description"), + default_site_mask="dirus", + results_cache=adir1_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_adir1.TAXON_COLORS, + ) + + +# N.B., here we use pytest_cases to parametrize tests. Each +# function whose name begins with "case_" defines a set of +# inputs to the test functions. See the documentation for +# pytest_cases for more information, e.g.: +# +# https://smarie.github.io/python-pytest-cases/#basic-usage +# +# We use this approach here because we want to use fixtures +# as test parameters, which is otherwise hard to do with +# pytest alone. + + +def case_ag3_sim(ag3_sim_fixture, ag3_sim_api): + return ag3_sim_fixture, ag3_sim_api + + +def case_af1_sim(af1_sim_fixture, af1_sim_api): + return af1_sim_fixture, af1_sim_api + + +def case_adir1_sim(adir1_sim_fixture, adir1_sim_api): + return adir1_sim_fixture, adir1_sim_api + + +@parametrize_with_cases("fixture,api", cases=".") +def test_pbs_gwss(fixture, api: AnophelesPbsAnalysis): + # Set up test parameters - need 3 distinct cohorts. 
+ all_sample_sets = api.sample_sets()["sample_set"].to_list() + all_countries = api.sample_metadata()["country"].dropna().unique().tolist() + if len(all_countries) < 3: + pytest.skip("Not enough distinct countries for PBS test (need 3).") + countries = np.random.choice(all_countries, size=3, replace=False).tolist() + cohort1_query = f"country == {countries[0]!r}" + cohort2_query = f"country == {countries[1]!r}" + cohort3_query = f"country == {countries[2]!r}" + pbs_params = dict( + contig=str(np.random.choice(api.contigs)), + sample_sets=all_sample_sets, + cohort1_query=cohort1_query, + cohort2_query=cohort2_query, + cohort3_query=cohort3_query, + site_mask=str(np.random.choice(api.site_mask_ids)), + window_size=int(np.random.randint(10, 51)), + min_cohort_size=1, + ) + + # Run main gwss function under test. + x, pbs = api.pbs_gwss(**pbs_params) + + # Check results. + assert isinstance(x, np.ndarray) + assert isinstance(pbs, np.ndarray) + assert x.ndim == 1 + assert pbs.ndim == 1 + assert x.shape == pbs.shape + assert x.dtype.kind == "f" + assert pbs.dtype.kind == "f" + + # Check plotting functions. + fig = api.plot_pbs_gwss_track(**pbs_params, show=False) + assert isinstance(fig, bokeh.models.Plot) + fig = api.plot_pbs_gwss(**pbs_params, show=False) + assert isinstance(fig, bokeh.models.GridPlot) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_pbs_gwss_normed(fixture, api: AnophelesPbsAnalysis): + # Test both normed=True and normed=False. 
+ all_sample_sets = api.sample_sets()["sample_set"].to_list() + all_countries = api.sample_metadata()["country"].dropna().unique().tolist() + if len(all_countries) < 3: + pytest.skip("Not enough distinct countries for PBS test (need 3).") + countries = np.random.choice(all_countries, size=3, replace=False).tolist() + cohort1_query = f"country == {countries[0]!r}" + cohort2_query = f"country == {countries[1]!r}" + cohort3_query = f"country == {countries[2]!r}" + common_params = dict( + contig=str(np.random.choice(api.contigs)), + sample_sets=all_sample_sets, + cohort1_query=cohort1_query, + cohort2_query=cohort2_query, + cohort3_query=cohort3_query, + site_mask=str(np.random.choice(api.site_mask_ids)), + window_size=int(np.random.randint(10, 51)), + min_cohort_size=1, + ) + + # Run with normed=True. + x_normed, pbs_normed = api.pbs_gwss(**common_params, normed=True) + assert isinstance(x_normed, np.ndarray) + assert isinstance(pbs_normed, np.ndarray) + assert x_normed.shape == pbs_normed.shape + + # Run with normed=False. + x_unnormed, pbs_unnormed = api.pbs_gwss(**common_params, normed=False) + assert isinstance(x_unnormed, np.ndarray) + assert isinstance(pbs_unnormed, np.ndarray) + assert x_unnormed.shape == pbs_unnormed.shape + + +@parametrize_with_cases("fixture,api", cases=".") +def test_pbs_gwss_window_size_too_large(fixture, api: AnophelesPbsAnalysis): + # When window_size exceeds available SNPs, a UserWarning must be issued and + # the function must still return a valid result using the adjusted window_size. 
+ all_sample_sets = api.sample_sets()["sample_set"].to_list() + all_countries = api.sample_metadata()["country"].dropna().unique().tolist() + if len(all_countries) < 3: + pytest.skip("Not enough distinct countries for PBS test (need 3).") + countries = np.random.choice(all_countries, size=3, replace=False).tolist() + cohort1_query = f"country == {countries[0]!r}" + cohort2_query = f"country == {countries[1]!r}" + cohort3_query = f"country == {countries[2]!r}" + with pytest.warns(UserWarning, match="window_size"): + x, pbs = api.pbs_gwss( + contig=str(np.random.choice(api.contigs)), + sample_sets=all_sample_sets, + cohort1_query=cohort1_query, + cohort2_query=cohort2_query, + cohort3_query=cohort3_query, + site_mask=str(np.random.choice(api.site_mask_ids)), + window_size=10_000_000, # far larger than any fixture SNP count + min_cohort_size=1, + ) + assert isinstance(x, np.ndarray) + assert isinstance(pbs, np.ndarray) + assert len(x) > 0 + assert x.shape == pbs.shape + + +@parametrize_with_cases("fixture,api", cases=".") +def test_pbs_gwss_too_few_snps(fixture, api: AnophelesPbsAnalysis): + # When min_snps_threshold exceeds available SNPs, a ValueError must be raised. 
+ all_sample_sets = api.sample_sets()["sample_set"].to_list() + all_countries = api.sample_metadata()["country"].dropna().unique().tolist() + if len(all_countries) < 3: + pytest.skip("Not enough distinct countries for PBS test (need 3).") + countries = np.random.choice(all_countries, size=3, replace=False).tolist() + cohort1_query = f"country == {countries[0]!r}" + cohort2_query = f"country == {countries[1]!r}" + cohort3_query = f"country == {countries[2]!r}" + with pytest.raises(ValueError, match="Too few SNP sites"): + api.pbs_gwss( + contig=str(np.random.choice(api.contigs)), + sample_sets=all_sample_sets, + cohort1_query=cohort1_query, + cohort2_query=cohort2_query, + cohort3_query=cohort3_query, + site_mask=str(np.random.choice(api.site_mask_ids)), + window_size=100, + min_cohort_size=1, + min_snps_threshold=10_000_000, + )