Skip to content

Commit f3afe7e

Browse files
committed
Add dataset-return options for SNP counts/diplotypes and remove duplicate SNP calls
1 parent e48d75b commit f3afe7e

4 files changed

Lines changed: 154 additions & 30 deletions

File tree

malariagen_data/anoph/distance.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def _biallelic_diplotype_pairwise_distances(
195195
max_missing_an,
196196
):
197197
# Compute diplotypes.
198-
gn, samples = self.biallelic_diplotypes(
198+
ds = self.biallelic_diplotypes(
199199
region=region,
200200
sample_sets=sample_sets,
201201
sample_indices=sample_indices,
@@ -211,7 +211,10 @@ def _biallelic_diplotype_pairwise_distances(
211211
min_minor_ac=min_minor_ac,
212212
n_snps=n_snps,
213213
thin_offset=thin_offset,
214+
return_dataset=True,
214215
)
216+
gn = ds["call_diplotype"].values
217+
samples = ds["sample_id"].values.astype("U")
215218

216219
# Record number of SNPs used.
217220
n_snps = gn.shape[0]

malariagen_data/anoph/pca.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _pca(
207207
inline_array,
208208
):
209209
# Load diplotypes.
210-
gn, samples = self.biallelic_diplotypes(
210+
ds_diplotypes = self.biallelic_diplotypes(
211211
region=region,
212212
n_snps=n_snps,
213213
thin_offset=thin_offset,
@@ -223,7 +223,10 @@ def _pca(
223223
random_seed=random_seed,
224224
chunks=chunks,
225225
inline_array=inline_array,
226+
return_dataset=True,
226227
)
228+
gn = ds_diplotypes["call_diplotype"].values
229+
samples = ds_diplotypes["sample_id"].values.astype("U")
227230

228231
with self._spinner(desc="Compute PCA"):
229232
# Exclude any samples prior to computing PCA.

malariagen_data/anoph/snp_data.py

Lines changed: 117 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,13 @@ def _snp_allele_counts(
14521452
Compute SNP allele counts. This returns the number of times each
14531453
SNP allele was observed in the selected samples.
14541454
""",
1455+
parameters=dict(
1456+
return_dataset="""
1457+
If True, return an xarray dataset containing SNP calls with
1458+
`variant_allele_count` added as an extra data variable. If False
1459+
(default), return a numpy array of allele counts.
1460+
"""
1461+
),
14551462
returns="""
14561463
A numpy array of shape (n_variants, 4), where the first column has
14571464
the reference allele (0) counts, the second column has the first
@@ -1481,7 +1488,26 @@ def snp_allele_counts(
14811488
random_seed: base_params.random_seed = 42,
14821489
inline_array: base_params.inline_array = base_params.inline_array_default,
14831490
chunks: base_params.chunks = base_params.native_chunks,
1484-
) -> np.ndarray:
1491+
return_dataset: bool = False,
1492+
) -> Any:
1493+
if return_dataset:
1494+
ds = self._snp_calls_with_allele_counts(
1495+
region=region,
1496+
sample_sets=sample_sets,
1497+
sample_query=sample_query,
1498+
sample_query_options=sample_query_options,
1499+
sample_indices=sample_indices,
1500+
site_mask=site_mask,
1501+
site_class=site_class,
1502+
cohort_size=cohort_size,
1503+
min_cohort_size=min_cohort_size,
1504+
max_cohort_size=max_cohort_size,
1505+
random_seed=random_seed,
1506+
inline_array=inline_array,
1507+
chunks=chunks,
1508+
)
1509+
return ds
1510+
14851511
# Change this name if you ever change the behaviour of this function,
14861512
# to invalidate any previously cached data.
14871513
name = "snp_allele_counts_v2"
@@ -1555,6 +1581,52 @@ def snp_allele_counts(
15551581
ac = results["ac"]
15561582
return ac
15571583

1584+
def _snp_calls_with_allele_counts(
    self,
    *,
    region,
    sample_sets,
    sample_query,
    sample_query_options,
    sample_indices,
    site_mask,
    site_class,
    cohort_size,
    min_cohort_size,
    max_cohort_size,
    random_seed,
    inline_array,
    chunks,
) -> xr.Dataset:
    """Load SNP calls and attach per-variant allele counts.

    All keyword arguments are forwarded unchanged to `self.snp_calls()`.
    Allele counts are then computed from the `call_genotype` variable of
    the resulting dataset, so the SNP calls only need to be accessed once
    to obtain both the calls and the counts.

    Returns
    -------
    xr.Dataset
        The dataset returned by `self.snp_calls()` with one additional
        data variable, `variant_allele_count`, declared on the same
        dimensions as `variant_allele`.
    """
    # Access the SNP calls for the requested selection.
    ds_snps = self.snp_calls(
        region=region,
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        sample_indices=sample_indices,
        site_mask=site_mask,
        site_class=site_class,
        cohort_size=cohort_size,
        min_cohort_size=min_cohort_size,
        max_cohort_size=max_cohort_size,
        random_seed=random_seed,
        inline_array=inline_array,
        chunks=chunks,
    )

    # Set up the allele count over the genotype calls. Using max_allele=3
    # counts alleles 0 through 3, i.e. four count columns per variant.
    genotypes = allel.GenotypeDaskArray(ds_snps["call_genotype"].data)
    pending_counts = genotypes.count_alleles(max_allele=3)

    # Run the computation with a progress indicator and unwrap the result
    # via `.values`.
    with self._dask_progress(desc="Compute SNP allele counts"):
        allele_counts = pending_counts.compute().values

    # Attach the counts on the same dimensions as the allele variable;
    # `assign` returns a new dataset rather than modifying `ds_snps`.
    count_dims = ds_snps["variant_allele"].dims
    return ds_snps.assign(variant_allele_count=(count_dims, allele_counts))
1629+
15581630
@_check_types
15591631
@doc(
15601632
summary="""
@@ -1897,8 +1969,8 @@ def biallelic_snp_calls(
18971969
sample_query=sample_query, sample_indices=sample_indices
18981970
)
18991971

1900-
# Perform an allele count.
1901-
ac = self.snp_allele_counts(
1972+
# Access SNP calls with allele counts in a single pass.
1973+
ds = self._snp_calls_with_allele_counts(
19021974
region=region,
19031975
sample_sets=sample_sets,
19041976
sample_query=sample_query,
@@ -1913,6 +1985,7 @@ def biallelic_snp_calls(
19131985
inline_array=inline_array,
19141986
chunks=chunks,
19151987
)
1988+
ac = ds["variant_allele_count"].values
19161989

19171990
# Locate biallelic SNPs.
19181991
loc_bi = allel.AlleleCountsArray(ac).is_biallelic()
@@ -1921,23 +1994,6 @@ def biallelic_snp_calls(
19211994
ac_bi = ac[loc_bi]
19221995
allele_mapping = _trim_alleles(ac_bi)
19231996

1924-
# Set up SNP calls.
1925-
ds = self.snp_calls(
1926-
region=region,
1927-
sample_sets=sample_sets,
1928-
sample_query=sample_query,
1929-
sample_query_options=sample_query_options,
1930-
sample_indices=sample_indices,
1931-
site_mask=site_mask,
1932-
site_class=site_class,
1933-
cohort_size=cohort_size,
1934-
min_cohort_size=min_cohort_size,
1935-
max_cohort_size=max_cohort_size,
1936-
random_seed=random_seed,
1937-
inline_array=inline_array,
1938-
chunks=chunks,
1939-
)
1940-
19411997
with self._spinner("Prepare biallelic SNP calls"):
19421998
# Subset to biallelic sites.
19431999
ds_bi = _dask_compress_dataset(ds, indexer=loc_bi, dim="variants")
@@ -2032,13 +2088,22 @@ def biallelic_snp_calls(
20322088
@_check_types
20332089
@doc(
20342090
summary="Load biallelic SNP genotypes.",
2035-
returns=dict(
2036-
gn="""
2037-
An array of shape (variants, samples) where each value counts the
2038-
number of alternate alleles per genotype call.
2039-
""",
2040-
samples="Sample identifiers.",
2091+
parameters=dict(
2092+
return_dataset="""
2093+
If True, return an xarray dataset with `call_diplotype` plus
2094+
`sample_id`, `variant_position`, and `variant_contig`. If False
2095+
(default), return a tuple `(gn, samples)` for backward
2096+
compatibility.
2097+
"""
20412098
),
2099+
returns="""
2100+
If `return_dataset` is False (default), return `(gn, samples)`, where
2101+
`gn` is an array of shape `(variants, samples)` counting alternate
2102+
alleles per genotype call and `samples` contains sample identifiers.
2103+
If `return_dataset` is True, return a dataset containing
2104+
`call_diplotype` with dimensions `(variants, samples)`, plus
2105+
`sample_id`, `variant_position`, and `variant_contig`.
2106+
""",
20422107
)
20432108
def biallelic_diplotypes(
20442109
self,
@@ -2059,7 +2124,8 @@ def biallelic_diplotypes(
20592124
thin_offset: base_params.thin_offset = 0,
20602125
inline_array: base_params.inline_array = base_params.inline_array_default,
20612126
chunks: base_params.chunks = base_params.native_chunks,
2062-
) -> Tuple[np.ndarray, np.ndarray]:
2127+
return_dataset: bool = False,
2128+
) -> Any:
20632129
# Change this name if you ever change the behaviour of this function, to
20642130
# invalidate any previously cached data.
20652131
name = "biallelic_diplotypes_v2"
@@ -2147,6 +2213,22 @@ def biallelic_diplotypes(
21472213
gn = results["gn"]
21482214
samples = results["samples"]
21492215

2216+
if return_dataset:
2217+
ds = xr.Dataset(
2218+
coords={
2219+
"sample_id": ("samples", samples),
2220+
"variant_position": ("variants", results["variant_position"]),
2221+
"variant_contig": ("variants", results["variant_contig"]),
2222+
},
2223+
data_vars={
2224+
"call_diplotype": (
2225+
("variants", "samples"),
2226+
gn,
2227+
)
2228+
},
2229+
)
2230+
return ds
2231+
21502232
return gn, samples
21512233

21522234
def _biallelic_diplotypes(
@@ -2193,6 +2275,8 @@ def _biallelic_diplotypes(
21932275

21942276
# Load sample IDs
21952277
samples = ds["sample_id"].values.astype("U")
2278+
variant_position = ds["variant_position"].values
2279+
variant_contig = ds["variant_contig"].values
21962280

21972281
# Compute diplotypes as the number of alt alleles per genotype call.
21982282
# with missing calls coded as -127.
@@ -2203,4 +2287,9 @@ def _biallelic_diplotypes(
22032287
missing = np.all(ds["call_genotype"].values == -1, axis=2)
22042288
gn[missing] = -127
22052289

2206-
return dict(samples=samples, gn=gn)
2290+
return dict(
2291+
samples=samples,
2292+
gn=gn,
2293+
variant_position=variant_position,
2294+
variant_contig=variant_contig,
2295+
)

tests/anoph/test_snp_data.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,19 @@ def check_snp_allele_counts(
11461146
)
11471147
assert_array_equal(ac, ac2)
11481148

1149+
# Check dataset return mode.
1150+
ds_ac = api.snp_allele_counts(
1151+
region=region,
1152+
sample_sets=sample_sets,
1153+
sample_query=sample_query,
1154+
sample_query_options=sample_query_options,
1155+
site_mask=site_mask,
1156+
return_dataset=True,
1157+
)
1158+
assert isinstance(ds_ac, xr.Dataset)
1159+
assert "variant_allele_count" in ds_ac
1160+
assert_array_equal(ds_ac["variant_allele_count"].values, ac)
1161+
11491162

11501163
@parametrize_with_cases(
11511164
"fixture,api", cases=".", filter=~ft.has_tag("single-sampleset")
@@ -1453,6 +1466,22 @@ def check_biallelic_snp_calls_and_diplotypes(
14531466
assert samples.shape[0] == gn.shape[1]
14541467
assert samples.tolist() == expected_samples
14551468

1469+
# Check dataset return mode.
1470+
ds_gn = api.biallelic_diplotypes(
1471+
region=region,
1472+
sample_sets=sample_sets,
1473+
site_mask=site_mask,
1474+
site_class=site_class,
1475+
min_minor_ac=min_minor_ac,
1476+
max_missing_an=max_missing_an,
1477+
n_snps=n_snps,
1478+
return_dataset=True,
1479+
)
1480+
assert isinstance(ds_gn, xr.Dataset)
1481+
assert "call_diplotype" in ds_gn
1482+
assert_array_equal(ds_gn["call_diplotype"].values, gn)
1483+
assert ds_gn["sample_id"].values.tolist() == expected_samples
1484+
14561485
return ds
14571486

14581487

0 commit comments

Comments
 (0)