|
| 1 | +import gzip |
| 2 | +import os |
| 3 | +from datetime import date |
| 4 | +from typing import Optional |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +from numpydoc_decorator import doc # type: ignore |
| 8 | + |
| 9 | +from .snp_data import AnophelesSnpData |
| 10 | +from . import base_params |
| 11 | +from . import plink_params |
| 12 | +from . import vcf_params |
| 13 | + |
| 14 | +# Supported FORMAT fields and their VCF header definitions. |
| 15 | +_VALID_FIELDS = {"GT", "GQ", "AD", "MQ"} |
| 16 | +_FORMAT_HEADERS = { |
| 17 | + "GT": '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">', |
| 18 | + "GQ": '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">', |
| 19 | + "AD": '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Allele Depth">', |
| 20 | + "MQ": '##FORMAT=<ID=MQ,Number=1,Type=Integer,Description="Mapping Quality">', |
| 21 | +} |
| 22 | + |
| 23 | + |
class SnpVcfExporter(
    AnophelesSnpData,
):
    """Cooperative mixin class adding export of SNP calls to VCF."""

    # Canonical order in which per-sample values are emitted. The VCF
    # specification requires GT to be the first FORMAT field when present.
    _FIELD_ORDER = ("GT", "GQ", "AD", "MQ")

    def __init__(
        self,
        **kwargs,
    ):
        # N.B., this class is designed to work cooperatively, and
        # so it's important that any remaining parameters are passed
        # to the superclass constructor.
        super().__init__(**kwargs)

    @staticmethod
    def _decode_allele(a) -> str:
        # Alleles may be stored as bytes or str depending on the backend;
        # normalise to str.
        return a.decode() if hasattr(a, "decode") else str(a)

    @staticmethod
    def _optional_call_data(ds, name: str, wanted: bool):
        # Return the dask array for an optional per-call field, or None
        # if the field was not requested or is absent from the dataset.
        if not wanted:
            return None
        try:
            return ds[name].data
        except KeyError:
            return None

    @staticmethod
    def _compute_optional(arr, start: int, stop: int):
        # Best-effort compute of one chunk of an optional field; the
        # underlying store may be missing data for some regions, in which
        # case the field is written as missing (".") instead of failing.
        if arr is None:
            return None
        try:
            return arr[start:stop].compute()
        except (FileNotFoundError, KeyError):
            return None

    @staticmethod
    def _write_vcf_header(f, fields, contigs, sample_ids) -> None:
        # Write VCF meta-information lines and the column header line.
        f.write("##fileformat=VCFv4.3\n")
        f.write(f"##fileDate={date.today().strftime('%Y%m%d')}\n")
        f.write("##source=malariagen_data\n")
        for contig in contigs:
            f.write(f"##contig=<ID={contig}>\n")
        for field in fields:
            f.write(_FORMAT_HEADERS[field] + "\n")
        header_cols = [
            "#CHROM",
            "POS",
            "ID",
            "REF",
            "ALT",
            "QUAL",
            "FILTER",
            "INFO",
            "FORMAT",
        ]
        f.write("\t".join(header_cols + list(sample_ids)) + "\n")

    @doc(
        summary="""
        Export SNP calls to Variant Call Format (VCF).
        """,
        extended_summary="""
        This function writes SNP calls to a VCF file. Data is written
        in chunks to avoid loading the entire genotype matrix into
        memory. Supports optional gzip compression when the output
        path ends with `.gz`.
        """,
        returns="""
        Path to the VCF output file.
        """,
    )
    def snp_calls_to_vcf(
        self,
        output_path: vcf_params.vcf_output_path,
        region: base_params.regions,
        sample_sets: Optional[base_params.sample_sets] = None,
        sample_query: Optional[base_params.sample_query] = None,
        sample_query_options: Optional[base_params.sample_query_options] = None,
        sample_indices: Optional[base_params.sample_indices] = None,
        site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
        inline_array: base_params.inline_array = base_params.inline_array_default,
        chunks: base_params.chunks = base_params.native_chunks,
        overwrite: plink_params.overwrite = False,
        fields: vcf_params.vcf_fields = ("GT",),
    ) -> str:
        base_params._validate_sample_selection_params(
            sample_query=sample_query, sample_indices=sample_indices
        )

        # Validate the requested FORMAT fields.
        fields = tuple(fields)
        unknown = set(fields) - _VALID_FIELDS
        if unknown:
            raise ValueError(
                f"Unknown FORMAT fields: {unknown}. "
                f"Valid fields are: {sorted(_VALID_FIELDS)}"
            )
        if "GT" not in fields:
            raise ValueError("GT must be included in fields.")

        # BUG FIX: per-sample values are always emitted in the fixed order
        # GT, GQ, AD, MQ below, so the FORMAT string must list the fields
        # in that same order regardless of the order the caller supplied
        # (previously fields=("GQ", "GT") produced a FORMAT column that
        # did not match the sample data). This also satisfies the VCF 4.3
        # requirement that GT come first.
        fields = tuple(f for f in self._FIELD_ORDER if f in fields)

        # Short-circuit if the output already exists and overwrite was
        # not requested.
        if os.path.exists(output_path) and not overwrite:
            return output_path

        ds = self.snp_calls(
            region=region,
            sample_sets=sample_sets,
            sample_query=sample_query,
            sample_query_options=sample_query_options,
            sample_indices=sample_indices,
            site_mask=site_mask,
            inline_array=inline_array,
            chunks=chunks,
        )

        sample_ids = ds["sample_id"].values
        contigs = ds.attrs.get("contigs", self.contigs)

        # Compress transparently when the output path ends with ".gz".
        compress = output_path.endswith(".gz")
        opener = gzip.open if compress else open

        # Determine which extra fields to include.
        include_gq = "GQ" in fields
        include_ad = "AD" in fields
        include_mq = "MQ" in fields
        format_str = ":".join(fields)

        with opener(output_path, "wt") as f:
            self._write_vcf_header(f, fields, contigs, sample_ids)

            # Extract dask arrays; data is computed chunk by chunk below
            # to bound memory usage.
            gt_data = ds["call_genotype"].data
            pos_data = ds["variant_position"].data
            contig_data = ds["variant_contig"].data
            allele_data = ds["variant_allele"].data

            # Optional field arrays — may not exist in all datasets.
            gq_data = self._optional_call_data(ds, "call_GQ", include_gq)
            ad_data = self._optional_call_data(ds, "call_AD", include_ad)
            mq_data = self._optional_call_data(ds, "call_MQ", include_mq)

            # Iterate over the natural chunking of the genotype array.
            chunk_sizes = gt_data.chunks[0]
            offsets = np.cumsum((0,) + chunk_sizes)

            # Write records in chunks.
            with self._spinner(f"Write VCF ({ds.sizes['variants']} variants)"):
                for ci in range(len(chunk_sizes)):
                    start = offsets[ci]
                    stop = offsets[ci + 1]
                    gt_chunk = gt_data[start:stop].compute()
                    pos_chunk = pos_data[start:stop].compute()
                    contig_chunk = contig_data[start:stop].compute()
                    allele_chunk = allele_data[start:stop].compute()

                    # Optional field chunks; None means "write as missing".
                    gq_chunk = self._compute_optional(gq_data, start, stop)
                    ad_chunk = self._compute_optional(ad_data, start, stop)
                    mq_chunk = self._compute_optional(mq_data, start, stop)

                    for j in range(gt_chunk.shape[0]):
                        chrom = contigs[contig_chunk[j]]
                        pos = str(pos_chunk[j])
                        alleles = allele_chunk[j]
                        ref = self._decode_allele(alleles[0])
                        # Empty strings pad the allele array to a fixed
                        # width; they are not real ALT alleles.
                        alt_alleles = [
                            s
                            for s in map(self._decode_allele, alleles[1:])
                            if s
                        ]
                        alt = ",".join(alt_alleles) if alt_alleles else "."

                        gt_row = gt_chunk[j]
                        n_samples = gt_row.shape[0]
                        sample_fields = []
                        for k in range(n_samples):
                            parts = []
                            # GT (always present, emitted first). N.B.,
                            # assumes diploid calls (ploidy of 2); a
                            # negative allele index encodes a missing call.
                            a0 = gt_row[k, 0]
                            a1 = gt_row[k, 1]
                            if a0 < 0 or a1 < 0:
                                parts.append("./.")
                            else:
                                parts.append(f"{a0}/{a1}")
                            # GQ.
                            if include_gq:
                                if gq_chunk is not None:
                                    v = gq_chunk[j, k]
                                    parts.append("." if v < 0 else str(v))
                                else:
                                    parts.append(".")
                            # AD.
                            if include_ad:
                                if ad_chunk is not None:
                                    parts.append(
                                        ",".join(
                                            "." if x < 0 else str(x)
                                            for x in ad_chunk[j, k]
                                        )
                                    )
                                else:
                                    parts.append(".")
                            # MQ.
                            if include_mq:
                                if mq_chunk is not None:
                                    v = mq_chunk[j, k]
                                    parts.append("." if v < 0 else str(v))
                                else:
                                    parts.append(".")
                            sample_fields.append(":".join(parts))

                        f.write(
                            f"{chrom}\t{pos}\t.\t{ref}\t{alt}\t.\t.\t.\t"
                            f"{format_str}\t"
                            + "\t".join(sample_fields)
                            + "\n"
                        )

        return output_path
0 commit comments