1111from . import plink_params
1212from . import vcf_params
1313
14+ # Supported FORMAT fields and their VCF header definitions.
15+ _VALID_FIELDS = {"GT" , "GQ" , "AD" , "MQ" }
16+ _FORMAT_HEADERS = {
17+ "GT" : '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">' ,
18+ "GQ" : '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">' ,
19+ "AD" : '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Allele Depth">' ,
20+ "MQ" : '##FORMAT=<ID=MQ,Number=1,Type=Integer,Description="Mapping Quality">' ,
21+ }
22+
1423
1524class VcfExporter (
1625 AnophelesSnpData ,
@@ -50,11 +59,23 @@ def snp_calls_to_vcf(
5059 inline_array : base_params .inline_array = base_params .inline_array_default ,
5160 chunks : base_params .chunks = base_params .native_chunks ,
5261 overwrite : plink_params .overwrite = False ,
62+ fields : vcf_params .vcf_fields = ("GT" ,),
5363 ) -> str :
5464 base_params ._validate_sample_selection_params (
5565 sample_query = sample_query , sample_indices = sample_indices
5666 )
5767
68+ # Validate fields parameter.
69+ fields = tuple (fields )
70+ unknown = set (fields ) - _VALID_FIELDS
71+ if unknown :
72+ raise ValueError (
73+ f"Unknown FORMAT fields: { unknown } . "
74+ f"Valid fields are: { sorted (_VALID_FIELDS )} "
75+ )
76+ if "GT" not in fields :
77+ raise ValueError ("GT must be included in fields." )
78+
5879 if os .path .exists (output_path ) and not overwrite :
5980 return output_path
6081
@@ -74,14 +95,21 @@ def snp_calls_to_vcf(
7495 compress = output_path .endswith (".gz" )
7596 opener = gzip .open if compress else open
7697
98+ # Determine which extra fields to include.
99+ include_gq = "GQ" in fields
100+ include_ad = "AD" in fields
101+ include_mq = "MQ" in fields
102+ format_str = ":" .join (fields )
103+
77104 with opener (output_path , "wt" ) as f :
78105 # Write VCF header.
79106 f .write ("##fileformat=VCFv4.3\n " )
80107 f .write (f"##fileDate={ date .today ().strftime ('%Y%m%d' )} \n " )
81108 f .write ("##source=malariagen_data\n " )
82109 for contig in contigs :
83110 f .write (f"##contig=<ID={ contig } >\n " )
84- f .write ('##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n ' )
111+ for field in fields :
112+ f .write (_FORMAT_HEADERS [field ] + "\n " )
85113 header_cols = [
86114 "#CHROM" ,
87115 "POS" ,
@@ -95,15 +123,36 @@ def snp_calls_to_vcf(
95123 ]
96124 f .write ("\t " .join (header_cols + list (sample_ids )) + "\n " )
97125
98- # Write records in chunks .
126+ # Extract dask arrays .
99127 gt_data = ds ["call_genotype" ].data
100128 pos_data = ds ["variant_position" ].data
101129 contig_data = ds ["variant_contig" ].data
102130 allele_data = ds ["variant_allele" ].data
103131
132+ # Optional field arrays — may not exist in all datasets.
133+ gq_data = None
134+ ad_data = None
135+ mq_data = None
136+ if include_gq :
137+ try :
138+ gq_data = ds ["call_GQ" ].data
139+ except KeyError :
140+ pass
141+ if include_ad :
142+ try :
143+ ad_data = ds ["call_AD" ].data
144+ except KeyError :
145+ pass
146+ if include_mq :
147+ try :
148+ mq_data = ds ["call_MQ" ].data
149+ except KeyError :
150+ pass
151+
104152 chunk_sizes = gt_data .chunks [0 ]
105153 offsets = np .cumsum ((0 ,) + chunk_sizes )
106154
155+ # Write records in chunks.
107156 with self ._spinner (f"Write VCF ({ ds .sizes ['variants' ]} variants)" ):
108157 for ci in range (len (chunk_sizes )):
109158 start = offsets [ci ]
@@ -113,6 +162,26 @@ def snp_calls_to_vcf(
113162 contig_chunk = contig_data [start :stop ].compute ()
114163 allele_chunk = allele_data [start :stop ].compute ()
115164
165+ # Compute optional field chunks, handling missing data.
166+ gq_chunk = None
167+ ad_chunk = None
168+ mq_chunk = None
169+ if gq_data is not None :
170+ try :
171+ gq_chunk = gq_data [start :stop ].compute ()
172+ except (FileNotFoundError , KeyError ):
173+ pass
174+ if ad_data is not None :
175+ try :
176+ ad_chunk = ad_data [start :stop ].compute ()
177+ except (FileNotFoundError , KeyError ):
178+ pass
179+ if mq_data is not None :
180+ try :
181+ mq_chunk = mq_data [start :stop ].compute ()
182+ except (FileNotFoundError , KeyError ):
183+ pass
184+
116185 for j in range (gt_chunk .shape [0 ]):
117186 chrom = contigs [contig_chunk [j ]]
118187 pos = str (pos_chunk [j ])
@@ -130,16 +199,47 @@ def snp_calls_to_vcf(
130199 alt = "," .join (alt_alleles ) if alt_alleles else "."
131200
132201 gt_row = gt_chunk [j ]
133- sample_fields = np .empty (gt_row .shape [0 ], dtype = object )
134- for k in range (gt_row .shape [0 ]):
202+ n_samples = gt_row .shape [0 ]
203+ sample_fields = np .empty (n_samples , dtype = object )
204+ for k in range (n_samples ):
205+ parts = []
206+ # GT (always present).
135207 a0 = gt_row [k , 0 ]
136208 a1 = gt_row [k , 1 ]
137209 if a0 < 0 or a1 < 0 :
138- sample_fields [ k ] = "./."
210+ parts . append ( "./." )
139211 else :
140- sample_fields [k ] = f"{ a0 } /{ a1 } "
141-
142- line = f"{ chrom } \t { pos } \t .\t { ref } \t { alt } \t .\t .\t .\t GT\t "
212+ parts .append (f"{ a0 } /{ a1 } " )
213+ # GQ.
214+ if include_gq :
215+ if gq_chunk is not None :
216+ v = gq_chunk [j , k ]
217+ parts .append ("." if v < 0 else str (v ))
218+ else :
219+ parts .append ("." )
220+ # AD.
221+ if include_ad :
222+ if ad_chunk is not None :
223+ ad_vals = ad_chunk [j , k ]
224+ parts .append (
225+ "," .join (
226+ "." if x < 0 else str (x ) for x in ad_vals
227+ )
228+ )
229+ else :
230+ parts .append ("." )
231+ # MQ.
232+ if include_mq :
233+ if mq_chunk is not None :
234+ v = mq_chunk [j , k ]
235+ parts .append ("." if v < 0 else str (v ))
236+ else :
237+ parts .append ("." )
238+ sample_fields [k ] = ":" .join (parts )
239+
240+ line = (
241+ f"{ chrom } \t { pos } \t .\t { ref } \t { alt } \t .\t .\t .\t { format_str } \t "
242+ )
143243 line += "\t " .join (sample_fields )
144244 f .write (line + "\n " )
145245
0 commit comments