Skip to content

Commit a33425b

Browse files
committed
feat: add fields parameter for VCF FORMAT output
1 parent ef17757 commit a33425b

3 files changed

Lines changed: 164 additions & 8 deletions

File tree

malariagen_data/anoph/to_vcf.py

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@
1111
from . import plink_params
1212
from . import vcf_params
1313

14+
# Supported FORMAT fields and their VCF header definitions.
#
# _FORMAT_HEADERS maps each supported FORMAT field ID to the complete
# "##FORMAT=..." meta-information line written to the VCF header for it.
_FORMAT_HEADERS = {
    "GT": '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">',
    "GQ": '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">',
    "AD": '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Allele Depth">',
    "MQ": '##FORMAT=<ID=MQ,Number=1,Type=Integer,Description="Mapping Quality">',
}

# The set of accepted field IDs is derived from the header table instead of
# being spelled out a second time, so every field that passes validation is
# guaranteed to have a header line. Frozen because it is a module constant.
_VALID_FIELDS = frozenset(_FORMAT_HEADERS)
22+
1423

1524
class VcfExporter(
1625
AnophelesSnpData,
@@ -50,11 +59,23 @@ def snp_calls_to_vcf(
5059
inline_array: base_params.inline_array = base_params.inline_array_default,
5160
chunks: base_params.chunks = base_params.native_chunks,
5261
overwrite: plink_params.overwrite = False,
62+
fields: vcf_params.vcf_fields = ("GT",),
5363
) -> str:
5464
base_params._validate_sample_selection_params(
5565
sample_query=sample_query, sample_indices=sample_indices
5666
)
5767

68+
# Validate fields parameter.
69+
fields = tuple(fields)
70+
unknown = set(fields) - _VALID_FIELDS
71+
if unknown:
72+
raise ValueError(
73+
f"Unknown FORMAT fields: {unknown}. "
74+
f"Valid fields are: {sorted(_VALID_FIELDS)}"
75+
)
76+
if "GT" not in fields:
77+
raise ValueError("GT must be included in fields.")
78+
5879
if os.path.exists(output_path) and not overwrite:
5980
return output_path
6081

@@ -74,14 +95,21 @@ def snp_calls_to_vcf(
7495
compress = output_path.endswith(".gz")
7596
opener = gzip.open if compress else open
7697

98+
# Determine which extra fields to include.
99+
include_gq = "GQ" in fields
100+
include_ad = "AD" in fields
101+
include_mq = "MQ" in fields
102+
format_str = ":".join(fields)
103+
77104
with opener(output_path, "wt") as f:
78105
# Write VCF header.
79106
f.write("##fileformat=VCFv4.3\n")
80107
f.write(f"##fileDate={date.today().strftime('%Y%m%d')}\n")
81108
f.write("##source=malariagen_data\n")
82109
for contig in contigs:
83110
f.write(f"##contig=<ID={contig}>\n")
84-
f.write('##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n')
111+
for field in fields:
112+
f.write(_FORMAT_HEADERS[field] + "\n")
85113
header_cols = [
86114
"#CHROM",
87115
"POS",
@@ -95,15 +123,36 @@ def snp_calls_to_vcf(
95123
]
96124
f.write("\t".join(header_cols + list(sample_ids)) + "\n")
97125

98-
# Write records in chunks.
126+
# Extract dask arrays.
99127
gt_data = ds["call_genotype"].data
100128
pos_data = ds["variant_position"].data
101129
contig_data = ds["variant_contig"].data
102130
allele_data = ds["variant_allele"].data
103131

132+
# Optional field arrays — may not exist in all datasets.
133+
gq_data = None
134+
ad_data = None
135+
mq_data = None
136+
if include_gq:
137+
try:
138+
gq_data = ds["call_GQ"].data
139+
except KeyError:
140+
pass
141+
if include_ad:
142+
try:
143+
ad_data = ds["call_AD"].data
144+
except KeyError:
145+
pass
146+
if include_mq:
147+
try:
148+
mq_data = ds["call_MQ"].data
149+
except KeyError:
150+
pass
151+
104152
chunk_sizes = gt_data.chunks[0]
105153
offsets = np.cumsum((0,) + chunk_sizes)
106154

155+
# Write records in chunks.
107156
with self._spinner(f"Write VCF ({ds.sizes['variants']} variants)"):
108157
for ci in range(len(chunk_sizes)):
109158
start = offsets[ci]
@@ -113,6 +162,26 @@ def snp_calls_to_vcf(
113162
contig_chunk = contig_data[start:stop].compute()
114163
allele_chunk = allele_data[start:stop].compute()
115164

165+
# Compute optional field chunks, handling missing data.
166+
gq_chunk = None
167+
ad_chunk = None
168+
mq_chunk = None
169+
if gq_data is not None:
170+
try:
171+
gq_chunk = gq_data[start:stop].compute()
172+
except (FileNotFoundError, KeyError):
173+
pass
174+
if ad_data is not None:
175+
try:
176+
ad_chunk = ad_data[start:stop].compute()
177+
except (FileNotFoundError, KeyError):
178+
pass
179+
if mq_data is not None:
180+
try:
181+
mq_chunk = mq_data[start:stop].compute()
182+
except (FileNotFoundError, KeyError):
183+
pass
184+
116185
for j in range(gt_chunk.shape[0]):
117186
chrom = contigs[contig_chunk[j]]
118187
pos = str(pos_chunk[j])
@@ -130,16 +199,47 @@ def snp_calls_to_vcf(
130199
alt = ",".join(alt_alleles) if alt_alleles else "."
131200

132201
gt_row = gt_chunk[j]
133-
sample_fields = np.empty(gt_row.shape[0], dtype=object)
134-
for k in range(gt_row.shape[0]):
202+
n_samples = gt_row.shape[0]
203+
sample_fields = np.empty(n_samples, dtype=object)
204+
for k in range(n_samples):
205+
parts = []
206+
# GT (always present).
135207
a0 = gt_row[k, 0]
136208
a1 = gt_row[k, 1]
137209
if a0 < 0 or a1 < 0:
138-
sample_fields[k] = "./."
210+
parts.append("./.")
139211
else:
140-
sample_fields[k] = f"{a0}/{a1}"
141-
142-
line = f"{chrom}\t{pos}\t.\t{ref}\t{alt}\t.\t.\t.\tGT\t"
212+
parts.append(f"{a0}/{a1}")
213+
# GQ.
214+
if include_gq:
215+
if gq_chunk is not None:
216+
v = gq_chunk[j, k]
217+
parts.append("." if v < 0 else str(v))
218+
else:
219+
parts.append(".")
220+
# AD.
221+
if include_ad:
222+
if ad_chunk is not None:
223+
ad_vals = ad_chunk[j, k]
224+
parts.append(
225+
",".join(
226+
"." if x < 0 else str(x) for x in ad_vals
227+
)
228+
)
229+
else:
230+
parts.append(".")
231+
# MQ.
232+
if include_mq:
233+
if mq_chunk is not None:
234+
v = mq_chunk[j, k]
235+
parts.append("." if v < 0 else str(v))
236+
else:
237+
parts.append(".")
238+
sample_fields[k] = ":".join(parts)
239+
240+
line = (
241+
f"{chrom}\t{pos}\t.\t{ref}\t{alt}\t.\t.\t.\t{format_str}\t"
242+
)
143243
line += "\t".join(sample_fields)
144244
f.write(line + "\n")
145245

malariagen_data/anoph/vcf_params.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Parameters for VCF exporter functions."""
22

3+
from typing import Tuple
4+
35
from typing_extensions import Annotated, TypeAlias
46

57
vcf_output_path: TypeAlias = Annotated[
@@ -9,3 +11,11 @@
911
gzip compression.
1012
""",
1113
]
14+
15+
# Type of the ``fields`` argument of ``snp_calls_to_vcf``. The string below is
# the parameter description, carried as ``Annotated`` metadata.
vcf_fields: TypeAlias = Annotated[
16+
Tuple[str, ...],
17+
"""
18+
FORMAT fields to include in the VCF output. Must include "GT".
19+
Supported fields: "GT", "GQ", "AD", "MQ".
20+
""",
21+
]

tests/anoph/test_vcf_exporter.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,49 @@ def test_vcf_exporter_gzip(fixture, api: VcfExporter, tmp_path):
148148
with gzip.open(output_path, "rt") as f:
149149
first_line = f.readline()
150150
assert first_line.strip() == "##fileformat=VCFv4.3"
151+
152+
153+
@parametrize_with_cases("fixture,api", cases=".")
def test_vcf_exporter_fields(fixture, api: VcfExporter, tmp_path):
    """Export GT and GQ, then check the FORMAT headers and every data record."""
    region = api.contigs[0]

    # Test with additional FORMAT fields.
    output_path = str(tmp_path / "test_fields.vcf")
    api.snp_calls_to_vcf(
        output_path=output_path,
        region=region,
        fields=("GT", "GQ"),
    )

    with open(output_path) as f:
        lines = f.readlines()

    # Check FORMAT header lines: one per requested field.
    format_lines = [line for line in lines if line.startswith("##FORMAT")]
    assert len(format_lines) == 2
    assert any("ID=GT" in line for line in format_lines)
    assert any("ID=GQ" in line for line in format_lines)

    data_lines = [line for line in lines if not line.startswith("#")]
    assert len(data_lines) > 0

    # Check every record, not only the first one, so a malformed line further
    # down the file cannot slip through.
    for record in data_lines:
        columns = record.strip().split("\t")

        # Check FORMAT column value.
        assert columns[8] == "GT:GQ"

        # Each sample field should have two colon-separated values.
        for sample_val in columns[9:]:
            parts = sample_val.split(":")
            assert len(parts) == 2, f"Expected GT:GQ, got {sample_val!r}"
184+
185+
186+
@parametrize_with_cases("fixture,api", cases=".")
def test_vcf_exporter_fields_gt_required(fixture, api: VcfExporter, tmp_path):
    """Requesting FORMAT fields without "GT" must raise a ValueError."""
    target = tmp_path / "test_no_gt.vcf"
    first_contig = api.contigs[0]

    # GQ is a supported field, so the request is valid apart from the
    # missing GT.
    with pytest.raises(ValueError, match="GT must be included"):
        api.snp_calls_to_vcf(
            output_path=str(target),
            region=first_contig,
            fields=("GQ",),
        )

0 commit comments

Comments
 (0)