Skip to content

Commit eed5276

Browse files
committed
remove extra snp_calls
1 parent 7ae95db commit eed5276

2 files changed

Lines changed: 63 additions & 24 deletions

File tree

malariagen_data/anoph/snp_data.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,7 +1440,15 @@ def _snp_allele_counts(
14401440
with self._dask_progress(desc="Compute SNP allele counts"):
14411441
ac = ac.compute()
14421442

1443-
results = dict(ac=ac.values)
1443+
# Cache variant metadata alongside allele counts so that
1444+
# snp_allele_counts(return_dataset=True) can reconstruct a
1445+
# Dataset without a redundant snp_calls() invocation.
1446+
results = dict(
1447+
ac=ac.values,
1448+
variant_position=ds_snps["variant_position"].values,
1449+
variant_contig=ds_snps["variant_contig"].values,
1450+
variant_allele=ds_snps["variant_allele"].values,
1451+
)
14441452

14451453
return results
14461454

@@ -1481,7 +1489,9 @@ def snp_allele_counts(
14811489
chunks: base_params.chunks = base_params.native_chunks,
14821490
return_dataset: base_params.return_dataset = False,
14831491
) -> Any:
1484-
name = "snp_allele_counts_v2"
1492+
# Bumped to v3 to include variant metadata in cached results,
1493+
# enabling Dataset reconstruction without extra snp_calls().
1494+
name = "snp_allele_counts_v3"
14851495

14861496
base_params._validate_sample_selection_params(
14871497
sample_query=sample_query, sample_indices=sample_indices
@@ -1547,24 +1557,23 @@ def snp_allele_counts(
15471557
ac = results["ac"]
15481558

15491559
if return_dataset:
1550-
ds = self.snp_calls(
1551-
region=params["region"],
1552-
sample_sets=params["sample_sets"],
1553-
sample_indices=params["sample_indices"],
1554-
site_mask=params["site_mask"],
1555-
site_class=site_class,
1556-
cohort_size=cohort_size,
1557-
min_cohort_size=min_cohort_size,
1558-
max_cohort_size=max_cohort_size,
1559-
random_seed=random_seed,
1560-
inline_array=inline_array,
1561-
chunks=chunks,
1562-
)
1563-
ds = ds.assign(
1564-
variant_allele_count=(
1565-
ds["variant_allele"].dims,
1566-
ac,
1567-
)
1560+
# Reconstruct the Dataset from cached arrays — no extra
1561+
# snp_calls() invocation required.
1562+
ds = xr.Dataset(
1563+
coords={
1564+
"variant_position": (DIM_VARIANT, results["variant_position"]),
1565+
"variant_contig": (DIM_VARIANT, results["variant_contig"]),
1566+
},
1567+
data_vars={
1568+
"variant_allele": (
1569+
(DIM_VARIANT, DIM_ALLELE),
1570+
results["variant_allele"],
1571+
),
1572+
"variant_allele_count": (
1573+
(DIM_VARIANT, DIM_ALLELE),
1574+
ac,
1575+
),
1576+
},
15681577
)
15691578
return ds
15701579

@@ -1911,7 +1920,9 @@ def biallelic_snp_calls(
19111920
sample_query=sample_query, sample_indices=sample_indices
19121921
)
19131922

1914-
ds = self.snp_allele_counts(
1923+
# Get the full SNP calls dataset (genotypes, sample_id, variant info).
1924+
# N.B., snp_calls() uses an LRU cache so this is efficient.
1925+
ds = self.snp_calls(
19151926
region=region,
19161927
sample_sets=sample_sets,
19171928
sample_query=sample_query,
@@ -1925,9 +1936,33 @@ def biallelic_snp_calls(
19251936
random_seed=random_seed,
19261937
inline_array=inline_array,
19271938
chunks=chunks,
1928-
return_dataset=True,
19291939
)
1930-
ac = ds["variant_allele_count"].values
1940+
1941+
# Get allele counts (uses the results cache, so no redundant
1942+
# genotype computation even if snp_calls was already called).
1943+
ac = self.snp_allele_counts(
1944+
region=region,
1945+
sample_sets=sample_sets,
1946+
sample_query=sample_query,
1947+
sample_query_options=sample_query_options,
1948+
sample_indices=sample_indices,
1949+
site_mask=site_mask,
1950+
site_class=site_class,
1951+
cohort_size=cohort_size,
1952+
min_cohort_size=min_cohort_size,
1953+
max_cohort_size=max_cohort_size,
1954+
random_seed=random_seed,
1955+
inline_array=inline_array,
1956+
chunks=chunks,
1957+
)
1958+
1959+
# Attach allele counts to the dataset.
1960+
ds = ds.assign(
1961+
variant_allele_count=(
1962+
ds["variant_allele"].dims,
1963+
ac,
1964+
)
1965+
)
19311966

19321967
# Locate biallelic SNPs.
19331968
loc_bi = allel.AlleleCountsArray(ac).is_biallelic()
@@ -2185,7 +2220,7 @@ def _biallelic_diplotypes(
21852220
thin_offset: base_params.thin_offset,
21862221
inline_array: base_params.inline_array,
21872222
chunks: base_params.chunks,
2188-
) -> Dict[str, np.ndarray]:
2223+
) -> xr.Dataset:
21892224
# Note: this function uses sample_indices and should not expect a sample_query.
21902225

21912226
# Access biallelic SNPs.

tests/anoph/test_snp_data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,10 @@ def check_snp_allele_counts(
11581158
assert isinstance(ds_ac, xr.Dataset)
11591159
assert "variant_allele_count" in ds_ac
11601160
assert_array_equal(ds_ac["variant_allele_count"].values, ac)
1161+
# Verify variant metadata is included.
1162+
assert "variant_position" in ds_ac.coords
1163+
assert "variant_contig" in ds_ac.coords
1164+
assert "variant_allele" in ds_ac
11611165

11621166

11631167
@parametrize_with_cases(

0 commit comments

Comments
 (0)