|
2 | 2 |
|
3 | 3 | import dask |
4 | 4 | import pandas as pd # type: ignore |
| 5 | +from pandas import CategoricalDtype |
| 6 | +import numpy as np # type: ignore |
| 7 | +import allel # type: ignore |
5 | 8 | import plotly.express as px # type: ignore |
6 | 9 |
|
7 | 10 | import malariagen_data |
8 | 11 | from .anopheles import AnophelesDataResource |
9 | 12 |
|
| 13 | +from numpydoc_decorator import doc |
| 14 | +from .util import check_types, _karyotype_tags_n_alt |
| 15 | +from .anoph import base_params |
| 16 | +from typing import Optional, Literal, Annotated, TypeAlias |
| 17 | + |
10 | 18 | # silence dask performance warnings |
11 | 19 | dask.config.set(**{"array.slicing.split_native_chunks": False}) # type: ignore |
12 | 20 |
|
@@ -75,6 +83,12 @@ def _setup_aim_palettes(): |
75 | 83 | } |
76 | 84 |
|
77 | 85 |
|
| 86 | +inversion_param: TypeAlias = Annotated[ |
| 87 | + Literal["2La", "2Rb", "2Rc_gam", "2Rc_col", "2Rd", "2Rj"], |
| 88 | + "Name of inversion to infer karyotype for.", |
| 89 | +] |
| 90 | + |
| 91 | + |
78 | 92 | class Ag3(AnophelesDataResource): |
79 | 93 | """Provides access to data from Ag3.x releases. |
80 | 94 |
|
@@ -341,3 +355,74 @@ def _results_cache_add_analysis_params(self, params): |
341 | 355 | super()._results_cache_add_analysis_params(params) |
342 | 356 | # override parent class to add AIM analysis |
343 | 357 | params["aim_analysis"] = self._aim_analysis |
| 358 | + |
@check_types
@doc(
    summary="Load tag SNPs for a given inversion in Ag.",
)
def load_inversion_tags(self, inversion: inversion_param) -> pd.DataFrame:
    # The tag SNP table is read from the package's bundled resources; the
    # imports are kept local so that only this method needs to change if
    # the table ends up being hosted somewhere else.
    from importlib import resources as importlib_resources
    from . import resources as package_resources

    # Read the full table of tag SNPs (all inversions).
    with importlib_resources.path(
        package_resources, "karyotype_tag_snps.csv"
    ) as csv_path:
        all_tag_snps = pd.read_csv(csv_path, sep=",")

    # Keep only the rows for the requested inversion. reset_index() gives
    # the result a fresh 0..n-1 index and retains the previous row labels
    # as an "index" column.
    selected = all_tag_snps.query(f"inversion == '{inversion}'")
    return selected.reset_index()
| 371 | + |
@check_types
@doc(
    summary="Infer karyotype from tag SNPs for a given inversion in Ag.",
)
def karyotype(
    self,
    inversion: inversion_param,
    sample_sets: Optional[base_params.sample_sets] = None,
    sample_query: Optional[base_params.sample_query] = None,
) -> pd.DataFrame:
    # Returns one row per sample with the columns:
    #   sample_id                   - sample identifier from the SNP calls
    #   inversion                   - the inversion name passed in
    #   karyotype_<inversion>_mean  - mean tag-allele count over called tag SNPs
    #   karyotype_<inversion>       - the mean rounded to an integer, stored as
    #                                 an ordered categorical over 0, 1, 2
    #   total_tag_snps              - number of tag SNPs with a called genotype

    # load tag snp data
    df_tagsnps = self.load_inversion_tags(inversion=inversion)
    inversion_pos = df_tagsnps["position"]
    inversion_alts = df_tagsnps["alt_allele"]
    # the first two characters of the inversion name are used as the contig
    # (e.g. "2La" -> "2L")
    contig = inversion[0:2]

    # get snp calls for the region spanned by the tag SNPs
    start, end = np.min(inversion_pos), np.max(inversion_pos)
    region = f"{contig}:{start}-{end}"

    ds_snps = self.snp_calls(
        region=region, sample_sets=sample_sets, sample_query=sample_query
    )
    geno = allel.GenotypeDaskArray(ds_snps["call_genotype"].data)
    pos = allel.SortedIndex(ds_snps["variant_position"].values)
    samples = ds_snps["sample_id"].values
    alts = ds_snps["variant_allele"].values.astype(str)

    # subset to position of inversion tags; locate_intersection returns one
    # boolean mask over the SNP calls and one over the tag positions
    loc_calls, loc_tags = pos.locate_intersection(inversion_pos)
    alts = alts[loc_calls]
    geno = geno.compress(loc_calls, axis=0).compute()
    # Apply the matching mask to the tag alleles so they stay row-aligned
    # with `alts` and `geno` even if some tag positions are absent from the
    # SNP calls. Re-label 0..n-1 so positional lookups by integer still work.
    # When every tag is present this leaves the alleles unchanged.
    inversion_alts = inversion_alts[loc_tags].reset_index(drop=True)

    with self._spinner("Inferring karyotype from tag SNPs"):
        # NOTE(review): presumably returns, per tag SNP and sample, the number
        # of alleles matching the tag alt allele - confirm against
        # util._karyotype_tags_n_alt
        gn_alt = _karyotype_tags_n_alt(
            gt=geno, alts=alts, inversion_alts=inversion_alts
        )
        is_called = geno.is_called()

        # calculate mean genotype for each sample whilst masking missing calls
        av_gts = np.mean(np.ma.MaskedArray(gn_alt, mask=~is_called), axis=0)
        total_sites = np.sum(is_called, axis=0)

        df = pd.DataFrame(
            {
                "sample_id": samples,
                "inversion": inversion,
                f"karyotype_{inversion}_mean": av_gts,
                # round the genotypes then convert to int
                f"karyotype_{inversion}": av_gts.round().astype(int),
                "total_tag_snps": total_sites,
            },
        )
        kt_dtype = CategoricalDtype(categories=[0, 1, 2], ordered=True)
        df[f"karyotype_{inversion}"] = df[f"karyotype_{inversion}"].astype(kt_dtype)

    return df
0 commit comments