@@ -1440,7 +1440,15 @@ def _snp_allele_counts(
14401440 with self ._dask_progress (desc = "Compute SNP allele counts" ):
14411441 ac = ac .compute ()
14421442
1443- results = dict (ac = ac .values )
1443+ # Cache variant metadata alongside allele counts so that
1444+ # snp_allele_counts(return_dataset=True) can reconstruct a
1445+ # Dataset without a redundant snp_calls() invocation.
1446+ results = dict (
1447+ ac = ac .values ,
1448+ variant_position = ds_snps ["variant_position" ].values ,
1449+ variant_contig = ds_snps ["variant_contig" ].values ,
1450+ variant_allele = ds_snps ["variant_allele" ].values ,
1451+ )
14441452
14451453 return results
14461454
@@ -1481,7 +1489,9 @@ def snp_allele_counts(
14811489 chunks : base_params .chunks = base_params .native_chunks ,
14821490 return_dataset : base_params .return_dataset = False ,
14831491 ) -> Any :
1484- name = "snp_allele_counts_v2"
1492+ # Bumped to v3 to include variant metadata in cached results,
1493+ # enabling Dataset reconstruction without extra snp_calls().
1494+ name = "snp_allele_counts_v3"
14851495
14861496 base_params ._validate_sample_selection_params (
14871497 sample_query = sample_query , sample_indices = sample_indices
@@ -1547,24 +1557,23 @@ def snp_allele_counts(
15471557 ac = results ["ac" ]
15481558
15491559 if return_dataset :
1550- ds = self .snp_calls (
1551- region = params ["region" ],
1552- sample_sets = params ["sample_sets" ],
1553- sample_indices = params ["sample_indices" ],
1554- site_mask = params ["site_mask" ],
1555- site_class = site_class ,
1556- cohort_size = cohort_size ,
1557- min_cohort_size = min_cohort_size ,
1558- max_cohort_size = max_cohort_size ,
1559- random_seed = random_seed ,
1560- inline_array = inline_array ,
1561- chunks = chunks ,
1562- )
1563- ds = ds .assign (
1564- variant_allele_count = (
1565- ds ["variant_allele" ].dims ,
1566- ac ,
1567- )
1560+ # Reconstruct the Dataset from cached arrays — no extra
1561+ # snp_calls() invocation required.
1562+ ds = xr .Dataset (
1563+ coords = {
1564+ "variant_position" : (DIM_VARIANT , results ["variant_position" ]),
1565+ "variant_contig" : (DIM_VARIANT , results ["variant_contig" ]),
1566+ },
1567+ data_vars = {
1568+ "variant_allele" : (
1569+ (DIM_VARIANT , DIM_ALLELE ),
1570+ results ["variant_allele" ],
1571+ ),
1572+ "variant_allele_count" : (
1573+ (DIM_VARIANT , DIM_ALLELE ),
1574+ ac ,
1575+ ),
1576+ },
15681577 )
15691578 return ds
15701579
@@ -1911,7 +1920,9 @@ def biallelic_snp_calls(
19111920 sample_query = sample_query , sample_indices = sample_indices
19121921 )
19131922
1914- ds = self .snp_allele_counts (
1923+ # Get the full SNP calls dataset (genotypes, sample_id, variant info).
1924+ # N.B., snp_calls() uses an LRU cache so this is efficient.
1925+ ds = self .snp_calls (
19151926 region = region ,
19161927 sample_sets = sample_sets ,
19171928 sample_query = sample_query ,
@@ -1925,9 +1936,33 @@ def biallelic_snp_calls(
19251936 random_seed = random_seed ,
19261937 inline_array = inline_array ,
19271938 chunks = chunks ,
1928- return_dataset = True ,
19291939 )
1930- ac = ds ["variant_allele_count" ].values
1940+
1941+ # Get allele counts. When a matching entry is already in the results
1942+ # cache this avoids recomputing them from the genotypes; otherwise they
1942+ # are computed once here (and cached for subsequent calls).
1943+ ac = self .snp_allele_counts (
1944+ region = region ,
1945+ sample_sets = sample_sets ,
1946+ sample_query = sample_query ,
1947+ sample_query_options = sample_query_options ,
1948+ sample_indices = sample_indices ,
1949+ site_mask = site_mask ,
1950+ site_class = site_class ,
1951+ cohort_size = cohort_size ,
1952+ min_cohort_size = min_cohort_size ,
1953+ max_cohort_size = max_cohort_size ,
1954+ random_seed = random_seed ,
1955+ inline_array = inline_array ,
1956+ chunks = chunks ,
1957+ )
1958+
1959+ # Attach allele counts to the dataset.
1960+ ds = ds .assign (
1961+ variant_allele_count = (
1962+ ds ["variant_allele" ].dims ,
1963+ ac ,
1964+ )
1965+ )
19311966
19321967 # Locate biallelic SNPs.
19331968 loc_bi = allel .AlleleCountsArray (ac ).is_biallelic ()
@@ -2185,7 +2220,7 @@ def _biallelic_diplotypes(
21852220 thin_offset : base_params .thin_offset ,
21862221 inline_array : base_params .inline_array ,
21872222 chunks : base_params .chunks ,
2188- ) -> Dict [ str , np . ndarray ] :
2223+ ) -> xr . Dataset :
21892224 # Note: this function uses sample_indices and should not expect a sample_query.
21902225
21912226 # Access biallelic SNPs.
0 commit comments