@@ -1452,6 +1452,13 @@ def _snp_allele_counts(
14521452 Compute SNP allele counts. This returns the number of times each
14531453 SNP allele was observed in the selected samples.
14541454 """ ,
1455+ parameters = dict (
1456+ return_dataset = """
1457+ If True, return an xarray dataset containing SNP calls with
1458+ `variant_allele_count` added as an extra data variable. If False
1459+ (default), return a numpy array of allele counts.
1460+ """
1461+ ),
14551462 returns = """
14561463 A numpy array of shape (n_variants, 4), where the first column has
14571464 the reference allele (0) counts, the second column has the first
@@ -1481,7 +1488,26 @@ def snp_allele_counts(
14811488 random_seed : base_params .random_seed = 42 ,
14821489 inline_array : base_params .inline_array = base_params .inline_array_default ,
14831490 chunks : base_params .chunks = base_params .native_chunks ,
1484- ) -> np .ndarray :
1491+ return_dataset : bool = False ,
1492+ ) -> Any :
1493+ if return_dataset :
1494+ ds = self ._snp_calls_with_allele_counts (
1495+ region = region ,
1496+ sample_sets = sample_sets ,
1497+ sample_query = sample_query ,
1498+ sample_query_options = sample_query_options ,
1499+ sample_indices = sample_indices ,
1500+ site_mask = site_mask ,
1501+ site_class = site_class ,
1502+ cohort_size = cohort_size ,
1503+ min_cohort_size = min_cohort_size ,
1504+ max_cohort_size = max_cohort_size ,
1505+ random_seed = random_seed ,
1506+ inline_array = inline_array ,
1507+ chunks = chunks ,
1508+ )
1509+ return ds
1510+
14851511 # Change this name if you ever change the behaviour of this function,
14861512 # to invalidate any previously cached data.
14871513 name = "snp_allele_counts_v2"
@@ -1555,6 +1581,52 @@ def snp_allele_counts(
15551581 ac = results ["ac" ]
15561582 return ac
15571583
1584+ def _snp_calls_with_allele_counts (
1585+ self ,
1586+ * ,
1587+ region ,
1588+ sample_sets ,
1589+ sample_query ,
1590+ sample_query_options ,
1591+ sample_indices ,
1592+ site_mask ,
1593+ site_class ,
1594+ cohort_size ,
1595+ min_cohort_size ,
1596+ max_cohort_size ,
1597+ random_seed ,
1598+ inline_array ,
1599+ chunks ,
1600+ ) -> xr .Dataset :
1601+ ds = self .snp_calls (
1602+ region = region ,
1603+ sample_sets = sample_sets ,
1604+ sample_query = sample_query ,
1605+ sample_query_options = sample_query_options ,
1606+ sample_indices = sample_indices ,
1607+ site_mask = site_mask ,
1608+ site_class = site_class ,
1609+ cohort_size = cohort_size ,
1610+ min_cohort_size = min_cohort_size ,
1611+ max_cohort_size = max_cohort_size ,
1612+ random_seed = random_seed ,
1613+ inline_array = inline_array ,
1614+ chunks = chunks ,
1615+ )
1616+
1617+ gt = allel .GenotypeDaskArray (ds ["call_genotype" ].data )
1618+ ac = gt .count_alleles (max_allele = 3 )
1619+ with self ._dask_progress (desc = "Compute SNP allele counts" ):
1620+ ac = ac .compute ().values
1621+
1622+ ds = ds .assign (
1623+ variant_allele_count = (
1624+ ds ["variant_allele" ].dims ,
1625+ ac ,
1626+ )
1627+ )
1628+ return ds
1629+
15581630 @_check_types
15591631 @doc (
15601632 summary = """
@@ -1897,8 +1969,8 @@ def biallelic_snp_calls(
18971969 sample_query = sample_query , sample_indices = sample_indices
18981970 )
18991971
1900- # Perform an allele count .
1901- ac = self .snp_allele_counts (
1972+ # Access SNP calls with allele counts in a single pass .
1973+ ds = self ._snp_calls_with_allele_counts (
19021974 region = region ,
19031975 sample_sets = sample_sets ,
19041976 sample_query = sample_query ,
@@ -1913,6 +1985,7 @@ def biallelic_snp_calls(
19131985 inline_array = inline_array ,
19141986 chunks = chunks ,
19151987 )
1988+ ac = ds ["variant_allele_count" ].values
19161989
19171990 # Locate biallelic SNPs.
19181991 loc_bi = allel .AlleleCountsArray (ac ).is_biallelic ()
@@ -1921,23 +1994,6 @@ def biallelic_snp_calls(
19211994 ac_bi = ac [loc_bi ]
19221995 allele_mapping = _trim_alleles (ac_bi )
19231996
1924- # Set up SNP calls.
1925- ds = self .snp_calls (
1926- region = region ,
1927- sample_sets = sample_sets ,
1928- sample_query = sample_query ,
1929- sample_query_options = sample_query_options ,
1930- sample_indices = sample_indices ,
1931- site_mask = site_mask ,
1932- site_class = site_class ,
1933- cohort_size = cohort_size ,
1934- min_cohort_size = min_cohort_size ,
1935- max_cohort_size = max_cohort_size ,
1936- random_seed = random_seed ,
1937- inline_array = inline_array ,
1938- chunks = chunks ,
1939- )
1940-
19411997 with self ._spinner ("Prepare biallelic SNP calls" ):
19421998 # Subset to biallelic sites.
19431999 ds_bi = _dask_compress_dataset (ds , indexer = loc_bi , dim = "variants" )
@@ -2032,13 +2088,22 @@ def biallelic_snp_calls(
20322088 @_check_types
20332089 @doc (
20342090 summary = "Load biallelic SNP genotypes." ,
2035- returns = dict (
2036- gn = """
2037- An array of shape (variants, samples) where each value counts the
2038- number of alternate alleles per genotype call.
2039- """ ,
2040- samples = "Sample identifiers." ,
2091+ parameters = dict (
2092+ return_dataset = """
2093+ If True, return an xarray dataset with `call_diplotype` plus
2094+ `sample_id`, `variant_position`, and `variant_contig`. If False
2095+ (default), return a tuple `(gn, samples)` for backward
2096+ compatibility.
2097+ """
20412098 ),
2099+ returns = """
2100+ If `return_dataset` is False (default), return `(gn, samples)`, where
2101+ `gn` is an array of shape `(variants, samples)` counting alternate
2102+ alleles per genotype call and `samples` contains sample identifiers.
2103+ If `return_dataset` is True, return a dataset containing
2104+ `call_diplotype` with dimensions `(variants, samples)`, plus
2105+ `sample_id`, `variant_position`, and `variant_contig`.
2106+ """ ,
20422107 )
20432108 def biallelic_diplotypes (
20442109 self ,
@@ -2059,7 +2124,8 @@ def biallelic_diplotypes(
20592124 thin_offset : base_params .thin_offset = 0 ,
20602125 inline_array : base_params .inline_array = base_params .inline_array_default ,
20612126 chunks : base_params .chunks = base_params .native_chunks ,
2062- ) -> Tuple [np .ndarray , np .ndarray ]:
2127+ return_dataset : bool = False ,
2128+ ) -> Any :
20632129 # Change this name if you ever change the behaviour of this function, to
20642130 # invalidate any previously cached data.
20652131 name = "biallelic_diplotypes_v2"
@@ -2147,6 +2213,22 @@ def biallelic_diplotypes(
21472213 gn = results ["gn" ]
21482214 samples = results ["samples" ]
21492215
2216+ if return_dataset :
2217+ ds = xr .Dataset (
2218+ coords = {
2219+ "sample_id" : ("samples" , samples ),
2220+ "variant_position" : ("variants" , results ["variant_position" ]),
2221+ "variant_contig" : ("variants" , results ["variant_contig" ]),
2222+ },
2223+ data_vars = {
2224+ "call_diplotype" : (
2225+ ("variants" , "samples" ),
2226+ gn ,
2227+ )
2228+ },
2229+ )
2230+ return ds
2231+
21502232 return gn , samples
21512233
21522234 def _biallelic_diplotypes (
@@ -2193,6 +2275,8 @@ def _biallelic_diplotypes(
21932275
21942276 # Load sample IDs
21952277 samples = ds ["sample_id" ].values .astype ("U" )
2278+ variant_position = ds ["variant_position" ].values
2279+ variant_contig = ds ["variant_contig" ].values
21962280
21972281 # Compute diplotypes as the number of alt alleles per genotype call.
21982282 # with missing calls coded as -127.
@@ -2203,4 +2287,9 @@ def _biallelic_diplotypes(
22032287 missing = np .all (ds ["call_genotype" ].values == - 1 , axis = 2 )
22042288 gn [missing ] = - 127
22052289
2206- return dict (samples = samples , gn = gn )
2290+ return dict (
2291+ samples = samples ,
2292+ gn = gn ,
2293+ variant_position = variant_position ,
2294+ variant_contig = variant_contig ,
2295+ )
0 commit comments