Skip to content

Commit 38059ff

Browse files
authored
Merge pull request #1071 from adilraza99/GH-1054-add-vcf-export
Add VCF export support for SNP call datasets
2 parents 779a9df + 1629f9e commit 38059ff

File tree

6 files changed

+511
-23
lines changed

6 files changed

+511
-23
lines changed

malariagen_data/anoph/cnv_frq.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,15 @@ def _gene_cnv_frequencies_advanced(
632632
columns=["gene_id", "gene_name", "cnv_type"],
633633
)
634634

635+
debug("sort variants for deterministic ordering")
636+
sort_index = df_variants.sort_values(
637+
["contig", "start", "cnv_type"]
638+
).index.values
639+
df_variants = df_variants.iloc[sort_index].reset_index(drop=True)
640+
count = count[sort_index]
641+
nobs = nobs[sort_index]
642+
frequency = frequency[sort_index]
643+
635644
debug("build the output dataset")
636645
ds_out = xr.Dataset()
637646

malariagen_data/anoph/to_vcf.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
import gzip
2+
import os
3+
from datetime import date
4+
from typing import Optional
5+
6+
import numpy as np
7+
from numpydoc_decorator import doc # type: ignore
8+
9+
from .snp_data import AnophelesSnpData
10+
from . import base_params
11+
from . import plink_params
12+
from . import vcf_params
13+
14+
# VCF meta-information header lines for each FORMAT field the exporter
# knows how to write.
_FORMAT_HEADERS = {
    "GT": '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">',
    "GQ": '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">',
    "AD": '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Allele Depth">',
    "MQ": '##FORMAT=<ID=MQ,Number=1,Type=Integer,Description="Mapping Quality">',
}

# The set of supported FORMAT field names, derived from the header table
# above so the two can never drift apart.
_VALID_FIELDS = set(_FORMAT_HEADERS)
22+
23+
24+
class SnpVcfExporter(
    AnophelesSnpData,
):
    """Mixin adding VCF export support for SNP call datasets."""

    def __init__(
        self,
        **kwargs,
    ):
        # N.B., this class is designed to work cooperatively, and
        # so it's important that any remaining parameters are passed
        # to the superclass constructor.
        super().__init__(**kwargs)

    @doc(
        summary="""
            Export SNP calls to Variant Call Format (VCF).
        """,
        extended_summary="""
            This function writes SNP calls to a VCF file. Data is written
            in chunks to avoid loading the entire genotype matrix into
            memory. Supports optional gzip compression when the output
            path ends with `.gz`.
        """,
        returns="""
            Path to the VCF output file.
        """,
    )
    def snp_calls_to_vcf(
        self,
        output_path: vcf_params.vcf_output_path,
        region: base_params.regions,
        sample_sets: Optional[base_params.sample_sets] = None,
        sample_query: Optional[base_params.sample_query] = None,
        sample_query_options: Optional[base_params.sample_query_options] = None,
        sample_indices: Optional[base_params.sample_indices] = None,
        site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
        inline_array: base_params.inline_array = base_params.inline_array_default,
        chunks: base_params.chunks = base_params.native_chunks,
        overwrite: plink_params.overwrite = False,
        fields: vcf_params.vcf_fields = ("GT",),
    ) -> str:
        base_params._validate_sample_selection_params(
            sample_query=sample_query, sample_indices=sample_indices
        )

        # Validate the requested FORMAT fields.
        fields = tuple(fields)
        unknown = set(fields) - _VALID_FIELDS
        if unknown:
            raise ValueError(
                f"Unknown FORMAT fields: {unknown}. "
                f"Valid fields are: {sorted(_VALID_FIELDS)}"
            )
        if "GT" not in fields:
            raise ValueError("GT must be included in fields.")

        # Idempotent behaviour: an existing output file is reused unless
        # the caller explicitly asks for it to be regenerated.
        if os.path.exists(output_path) and not overwrite:
            return output_path

        ds = self.snp_calls(
            region=region,
            sample_sets=sample_sets,
            sample_query=sample_query,
            sample_query_options=sample_query_options,
            sample_indices=sample_indices,
            site_mask=site_mask,
            inline_array=inline_array,
            chunks=chunks,
        )

        sample_ids = ds["sample_id"].values
        contigs = ds.attrs.get("contigs", self.contigs)
        compress = output_path.endswith(".gz")
        opener = gzip.open if compress else open

        # Determine which extra fields to include. The FORMAT column is
        # written in the caller-supplied field order, so per-sample values
        # below are emitted by iterating `fields` in that same order —
        # NOT in a fixed canonical order — to keep them aligned.
        include_gq = "GQ" in fields
        include_ad = "AD" in fields
        include_mq = "MQ" in fields
        format_str = ":".join(fields)

        with opener(output_path, "wt") as f:
            # Write VCF header.
            f.write("##fileformat=VCFv4.3\n")
            f.write(f"##fileDate={date.today().strftime('%Y%m%d')}\n")
            f.write("##source=malariagen_data\n")
            for contig in contigs:
                f.write(f"##contig=<ID={contig}>\n")
            for field in fields:
                f.write(_FORMAT_HEADERS[field] + "\n")
            header_cols = [
                "#CHROM",
                "POS",
                "ID",
                "REF",
                "ALT",
                "QUAL",
                "FILTER",
                "INFO",
                "FORMAT",
            ]
            f.write("\t".join(header_cols + list(sample_ids)) + "\n")

            # Extract dask arrays.
            gt_data = ds["call_genotype"].data
            pos_data = ds["variant_position"].data
            contig_data = ds["variant_contig"].data
            allele_data = ds["variant_allele"].data

            # Optional field arrays — may not exist in all datasets; when
            # absent, missing values (".") are written for that field.
            gq_data = None
            ad_data = None
            mq_data = None
            if include_gq:
                try:
                    gq_data = ds["call_GQ"].data
                except KeyError:
                    pass
            if include_ad:
                try:
                    ad_data = ds["call_AD"].data
                except KeyError:
                    pass
            if include_mq:
                try:
                    mq_data = ds["call_MQ"].data
                except KeyError:
                    pass

            # Iterate over the dask chunk boundaries of the genotype array
            # so only one chunk is materialised in memory at a time.
            chunk_sizes = gt_data.chunks[0]
            offsets = np.cumsum((0,) + chunk_sizes)

            # Write records in chunks.
            with self._spinner(f"Write VCF ({ds.sizes['variants']} variants)"):
                for ci in range(len(chunk_sizes)):
                    start = offsets[ci]
                    stop = offsets[ci + 1]
                    gt_chunk = gt_data[start:stop].compute()
                    pos_chunk = pos_data[start:stop].compute()
                    contig_chunk = contig_data[start:stop].compute()
                    allele_chunk = allele_data[start:stop].compute()

                    # Compute optional field chunks, handling missing data.
                    gq_chunk = None
                    ad_chunk = None
                    mq_chunk = None
                    if gq_data is not None:
                        try:
                            gq_chunk = gq_data[start:stop].compute()
                        except (FileNotFoundError, KeyError):
                            pass
                    if ad_data is not None:
                        try:
                            ad_chunk = ad_data[start:stop].compute()
                        except (FileNotFoundError, KeyError):
                            pass
                    if mq_data is not None:
                        try:
                            mq_chunk = mq_data[start:stop].compute()
                        except (FileNotFoundError, KeyError):
                            pass

                    for j in range(gt_chunk.shape[0]):
                        chrom = contigs[contig_chunk[j]]
                        pos = str(pos_chunk[j])
                        alleles = allele_chunk[j]
                        ref = (
                            alleles[0].decode()
                            if hasattr(alleles[0], "decode")
                            else str(alleles[0])
                        )
                        # NOTE(review): empty ALT strings are dropped, but
                        # genotype allele indices are written unchanged —
                        # this assumes empty alleles only ever pad the tail
                        # of the allele list. TODO confirm for all datasets.
                        alt_alleles = []
                        for a in alleles[1:]:
                            s = a.decode() if hasattr(a, "decode") else str(a)
                            if s:
                                alt_alleles.append(s)
                        alt = ",".join(alt_alleles) if alt_alleles else "."

                        gt_row = gt_chunk[j]
                        n_samples = gt_row.shape[0]
                        sample_fields = np.empty(n_samples, dtype=object)
                        for k in range(n_samples):
                            # Build per-sample values in the same order as
                            # `format_str` (caller-supplied field order).
                            parts = []
                            for field in fields:
                                if field == "GT":
                                    # Assumes diploid calls (ploidy 2);
                                    # negative allele indices mean missing.
                                    a0 = gt_row[k, 0]
                                    a1 = gt_row[k, 1]
                                    if a0 < 0 or a1 < 0:
                                        parts.append("./.")
                                    else:
                                        parts.append(f"{a0}/{a1}")
                                elif field == "GQ" and gq_chunk is not None:
                                    v = gq_chunk[j, k]
                                    parts.append("." if v < 0 else str(v))
                                elif field == "AD" and ad_chunk is not None:
                                    ad_vals = ad_chunk[j, k]
                                    parts.append(
                                        ",".join(
                                            "." if x < 0 else str(x)
                                            for x in ad_vals
                                        )
                                    )
                                elif field == "MQ" and mq_chunk is not None:
                                    v = mq_chunk[j, k]
                                    parts.append("." if v < 0 else str(v))
                                else:
                                    # Field requested but data unavailable.
                                    parts.append(".")
                            sample_fields[k] = ":".join(parts)

                        line = (
                            f"{chrom}\t{pos}\t.\t{ref}\t{alt}\t.\t.\t.\t{format_str}\t"
                        )
                        line += "\t".join(sample_fields)
                        f.write(line + "\n")

        return output_path
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Parameters for VCF exporter functions."""
2+
3+
from typing import Tuple
4+
5+
from typing_extensions import Annotated, TypeAlias
6+
7+
vcf_output_path: TypeAlias = Annotated[
8+
str,
9+
"""
10+
Path to write the VCF output file. Use a `.vcf.gz` extension to enable
11+
gzip compression.
12+
""",
13+
]
14+
15+
vcf_fields: TypeAlias = Annotated[
16+
Tuple[str, ...],
17+
"""
18+
FORMAT fields to include in the VCF output. Must include "GT".
19+
Supported fields: "GT", "GQ", "AD", "MQ".
20+
""",
21+
]

malariagen_data/anopheles.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .anoph.sample_metadata import AnophelesSampleMetadata
3636
from .anoph.snp_data import AnophelesSnpData
3737
from .anoph.to_plink import PlinkConverter
38+
from .anoph.to_vcf import SnpVcfExporter
3839
from .anoph.g123 import AnophelesG123Analysis
3940
from .anoph.fst import AnophelesFstAnalysis
4041
from .anoph.h12 import AnophelesH12Analysis
@@ -87,6 +88,7 @@ class AnophelesDataResource(
8788
AnophelesDistanceAnalysis,
8889
AnophelesPca,
8990
PlinkConverter,
91+
SnpVcfExporter,
9092
AnophelesIgv,
9193
AnophelesKaryotypeAnalysis,
9294
AnophelesAimData,

tests/anoph/test_snp_frq.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -986,18 +986,25 @@ def check_snp_allele_frequencies_advanced(
986986
api = add_random_year(api=api)
987987

988988
# Run function under test.
989-
ds = api.snp_allele_frequencies_advanced(
990-
transcript=transcript,
991-
area_by=area_by,
992-
period_by=period_by,
993-
sample_sets=sample_sets,
994-
sample_query=sample_query,
995-
sample_query_options=sample_query_options,
996-
min_cohort_size=min_cohort_size,
997-
nobs_mode=nobs_mode,
998-
variant_query=variant_query,
999-
site_mask=site_mask,
1000-
)
989+
try:
990+
ds = api.snp_allele_frequencies_advanced(
991+
transcript=transcript,
992+
area_by=area_by,
993+
period_by=period_by,
994+
sample_sets=sample_sets,
995+
sample_query=sample_query,
996+
sample_query_options=sample_query_options,
997+
min_cohort_size=min_cohort_size,
998+
nobs_mode=nobs_mode,
999+
variant_query=variant_query,
1000+
site_mask=site_mask,
1001+
)
1002+
except ValueError as e:
1003+
if "No cohorts available" in str(e):
1004+
# Random parameters produced no valid cohorts; this is
1005+
# expected to happen occasionally and is not a bug.
1006+
return
1007+
raise
10011008

10021009
# Check the result.
10031010
assert isinstance(ds, xr.Dataset)
@@ -1184,17 +1191,24 @@ def check_aa_allele_frequencies_advanced(
11841191
api = add_random_year(api=api)
11851192

11861193
# Run function under test.
1187-
ds = api.aa_allele_frequencies_advanced(
1188-
transcript=transcript,
1189-
area_by=area_by,
1190-
period_by=period_by,
1191-
sample_sets=sample_sets,
1192-
sample_query=sample_query,
1193-
sample_query_options=sample_query_options,
1194-
min_cohort_size=min_cohort_size,
1195-
nobs_mode=nobs_mode,
1196-
variant_query=variant_query,
1197-
)
1194+
try:
1195+
ds = api.aa_allele_frequencies_advanced(
1196+
transcript=transcript,
1197+
area_by=area_by,
1198+
period_by=period_by,
1199+
sample_sets=sample_sets,
1200+
sample_query=sample_query,
1201+
sample_query_options=sample_query_options,
1202+
min_cohort_size=min_cohort_size,
1203+
nobs_mode=nobs_mode,
1204+
variant_query=variant_query,
1205+
)
1206+
except ValueError as e:
1207+
if "No cohorts available" in str(e):
1208+
# Random parameters produced no valid cohorts; this is
1209+
# expected to happen occasionally and is not a bug.
1210+
return
1211+
raise
11981212

11991213
# Check the result.
12001214
assert isinstance(ds, xr.Dataset)

0 commit comments

Comments
 (0)