Skip to content

Commit 9f13fb0

Browse files
authored
Improve performance when accessing biallelic SNP calls (#623)
* fix bug with biallelic snp calls and variant_allele * optimisations, wip * massive thrash * comment * refactor * deal with strange performance issue in zarr and fsspec * poetry update * tweaks * revert silly mistake * avoid getattr pickle black hole of doom * fix gcs bucket * remove outputs * tweak test * tweak test
1 parent 6c3554b commit 9f13fb0

9 files changed

Lines changed: 672 additions & 3650 deletions

File tree

malariagen_data/anoph/snp_data.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
DIM_VARIANT,
1818
CacheMiss,
1919
Region,
20-
apply_allele_mapping,
2120
check_types,
2221
da_compress,
2322
da_concat,
2423
da_from_zarr,
24+
dask_apply_allele_mapping,
2525
dask_compress_dataset,
26+
dask_genotype_array_map_alleles,
2627
init_zarr_store,
2728
locate_region,
2829
parse_multi_region,
@@ -565,6 +566,7 @@ def _snp_variants_for_contig(
565566
ref = da_from_zarr(ref_z, inline_array=inline_array, chunks=chunks)
566567
alt = da_from_zarr(alt_z, inline_array=inline_array, chunks=chunks)
567568
variant_allele = da.concatenate([ref[:, None], alt], axis=1)
569+
variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1))
568570
data_vars["variant_allele"] = [DIM_VARIANT, DIM_ALLELE], variant_allele
569571

570572
# Set up variant_contig.
@@ -1611,7 +1613,7 @@ def biallelic_snp_calls(
16111613

16121614
with self._spinner("Prepare biallelic SNP calls"):
16131615
# Subset to biallelic sites.
1614-
ds_bi = ds.isel(variants=loc_bi)
1616+
ds_bi = dask_compress_dataset(ds, indexer=loc_bi, dim="variants")
16151617

16161618
# Start building a new dataset.
16171619
coords: Dict[str, Any] = dict()
@@ -1624,42 +1626,50 @@ def biallelic_snp_calls(
16241626
coords["variant_contig"] = ("variants",), ds_bi["variant_contig"].data
16251627

16261628
# Store position.
1627-
coords["variant_position"] = ("variants",), ds_bi["variant_position"].data
1629+
variant_position = ds_bi["variant_position"].data
1630+
coords["variant_position"] = ("variants",), variant_position
1631+
1632+
# Prepare allele mapping for dask computations.
1633+
allele_mapping_zarr = zarr.array(allele_mapping)
1634+
allele_mapping_dask = da_from_zarr(
1635+
allele_mapping_zarr, chunks="native", inline_array=True
1636+
)
16281637

16291638
# Store alleles, transformed.
1630-
variant_allele = ds_bi["variant_allele"].data
1631-
variant_allele = variant_allele.rechunk((variant_allele.chunks[0], -1))
1632-
variant_allele_out = da.map_blocks(
1633-
lambda block: apply_allele_mapping(block, allele_mapping, max_allele=1),
1634-
variant_allele,
1635-
dtype=variant_allele.dtype,
1636-
chunks=(variant_allele.chunks[0], [2]),
1639+
variant_allele_dask = ds_bi["variant_allele"].data
1640+
variant_allele_out = dask_apply_allele_mapping(
1641+
variant_allele_dask, allele_mapping_dask, max_allele=1
16371642
)
16381643
data_vars["variant_allele"] = ("variants", "alleles"), variant_allele_out
16391644

1640-
# Store allele counts, transformed, so we don't have to recompute.
1641-
ac_out = apply_allele_mapping(ac_bi, allele_mapping, max_allele=1)
1645+
# Store allele counts, transformed.
1646+
ac_bi_zarr = zarr.array(ac_bi)
1647+
ac_bi_dask = da_from_zarr(ac_bi_zarr, chunks="native", inline_array=True)
1648+
ac_out = dask_apply_allele_mapping(
1649+
ac_bi_dask, allele_mapping_dask, max_allele=1
1650+
)
16421651
data_vars["variant_allele_count"] = ("variants", "alleles"), ac_out
16431652

16441653
# Store genotype calls, transformed.
1645-
gt = ds_bi["call_genotype"].data
1646-
gt_out = allel.GenotypeDaskArray(gt).map_alleles(allele_mapping)
1654+
gt_dask = ds_bi["call_genotype"].data
1655+
gt_out = dask_genotype_array_map_alleles(gt_dask, allele_mapping_dask)
16471656
data_vars["call_genotype"] = (
16481657
(
16491658
"variants",
16501659
"samples",
16511660
"ploidy",
16521661
),
1653-
gt_out.values,
1662+
gt_out,
16541663
)
16551664

16561665
# Build dataset.
16571666
ds_out = xr.Dataset(coords=coords, data_vars=data_vars, attrs=ds.attrs)
16581667

16591668
# Apply conditions.
16601669
if max_missing_an is not None or min_minor_ac is not None:
1670+
ac_out_computed = ac_out.compute()
16611671
loc_out = np.ones(ds_out.sizes["variants"], dtype=bool)
1662-
an = ac_out.sum(axis=1)
1672+
an = ac_out_computed.sum(axis=1)
16631673

16641674
# Apply missingness condition.
16651675
if max_missing_an is not None:
@@ -1673,20 +1683,21 @@ def biallelic_snp_calls(
16731683

16741684
# Apply minor allele count condition.
16751685
if min_minor_ac is not None:
1676-
ac_minor = ac_out.min(axis=1)
1686+
ac_minor = ac_out_computed.min(axis=1)
16771687
if isinstance(min_minor_ac, float):
16781688
ac_minor_frac = ac_minor / an
16791689
loc_minor = ac_minor_frac >= min_minor_ac
16801690
else:
16811691
loc_minor = ac_minor >= min_minor_ac
16821692
loc_out &= loc_minor
16831693

1684-
ds_out = ds_out.isel(variants=loc_out)
1694+
# Apply selection from conditions.
1695+
ds_out = dask_compress_dataset(ds_out, indexer=loc_out, dim="variants")
16851696

16861697
# Try to meet target number of SNPs.
16871698
if n_snps is not None:
16881699
if ds_out.sizes["variants"] > (n_snps * 2):
1689-
# Do some thinning.
1700+
# Apply thinning.
16901701
thin_step = ds_out.sizes["variants"] // n_snps
16911702
loc_thin = slice(thin_offset, None, thin_step)
16921703
ds_out = ds_out.isel(variants=loc_thin)

malariagen_data/util.py

Lines changed: 132 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import typeguard
3030
import xarray as xr
3131
import zarr # type: ignore
32+
33+
# zarr >= 2.11.0
34+
from zarr.storage import BaseStore # type: ignore
3235
from fsspec.core import url_to_fs # type: ignore
3336
from fsspec.mapping import FSMap # type: ignore
3437
from numpydoc_decorator.impl import humanize_type # type: ignore
@@ -113,46 +116,40 @@ def unpack_gff3_attributes(df: pd.DataFrame, attributes: Tuple[str, ...]):
113116
return df
114117

115118

116-
# zarr compatibility, version 2.11.0 introduced the BaseStore class
117-
# see also https://github.com/malariagen/malariagen-data-python/issues/129
118-
119-
try:
120-
# zarr >= 2.11.0
121-
from zarr.storage import KVStore # type: ignore
122-
123-
class SafeStore(KVStore):
124-
def __getitem__(self, key):
125-
try:
126-
return self._mutable_mapping[key]
127-
except KeyError as e:
128-
# raise a different error to ensure zarr propagates the exception, rather than filling
129-
raise FileNotFoundError(e)
119+
class SafeStore(BaseStore):
    """Read-only wrapper around a zarr store that never hides missing keys.

    This class wraps any zarr store and ensures that missing chunks
    will not get automatically filled but will raise an exception. There
    should be no missing chunks in any of the datasets we host.

    Item lookup, membership tests, iteration and length are forwarded to
    the wrapped store. Any other attribute access is passed through to the
    wrapped store as well. Item assignment and deletion are not supported.

    """

    def __init__(self, store):
        # The wrapped store; every read is delegated to it.
        self._store = store

    def __getitem__(self, key):
        try:
            return self._store[key]
        except KeyError as e:
            # Raise a different error to ensure zarr propagates the exception,
            # rather than filling.
            raise FileNotFoundError(e) from e

    def __contains__(self, key):
        # Ask the wrapped store directly. Assuming BaseStore inherits the
        # Mapping ABC membership test (which probes __getitem__ and only
        # swallows KeyError), ``key in store`` would otherwise raise
        # FileNotFoundError for a missing key instead of returning False.
        return key in self._store

    def __getattr__(self, attr):
        if attr in ("__setstate__", "_store"):
            # "__setstate__" is a special method looked up during unpickling,
            # don't pass it through. "_store" is only looked up here when it
            # has not been set yet (e.g. on an instance created without
            # __init__); delegating would recurse forever via self._store.
            raise AttributeError(attr)
        # Pass through all other attribute access to the wrapped store.
        return getattr(self._store, attr)

    def __iter__(self):
        return iter(self._store)

    def __len__(self):
        return len(self._store)

    def __setitem__(self, key, value):
        # Writing is not supported. The signature must accept both key and
        # value, otherwise ``store[key] = value`` fails with a TypeError
        # before this NotImplementedError can be raised.
        raise NotImplementedError

    def __delitem__(self, key):
        # Deleting is not supported.
        raise NotImplementedError
156153

157154

158155
class SiteClass(Enum):
@@ -269,7 +266,11 @@ def da_from_zarr(
269266
dask_chunks = chunks
270267

271268
kwargs = dict(
272-
chunks=dask_chunks, fancy=False, lock=False, inline_array=inline_array
269+
inline_array=inline_array,
270+
chunks=dask_chunks,
271+
fancy=True,
272+
lock=False,
273+
asarray=True,
273274
)
274275
try:
275276
d = da.from_array(z, **kwargs)
@@ -301,14 +302,19 @@ def dask_compress_dataset(ds, indexer, dim):
301302
indexer = ds[indexer].data
302303

303304
# sanity checks
304-
assert isinstance(indexer, da.Array)
305305
assert indexer.ndim == 1
306306
assert indexer.dtype == bool
307307
assert indexer.shape[0] == ds.sizes[dim]
308308

309-
# temporarily compute the indexer once, to avoid multiple reads from
310-
# the underlying data
311-
indexer_computed = indexer.compute()
309+
if isinstance(indexer, da.Array):
310+
# temporarily compute the indexer once, to avoid multiple reads from
311+
# the underlying data
312+
indexer_computed = indexer.compute()
313+
else:
314+
assert isinstance(indexer, np.ndarray)
315+
indexer_computed = indexer
316+
indexer_zarr = zarr.array(indexer_computed)
317+
indexer = da_from_zarr(indexer_zarr, chunks="native", inline_array=True)
312318

313319
coords = dict()
314320
for k in ds.coords:
@@ -353,32 +359,36 @@ def da_compress(
353359
):
354360
"""Wrapper for dask.array.compress() which computes chunk sizes faster."""
355361

356-
# sanity checks
362+
# Sanity checks.
363+
assert indexer.ndim == 1
364+
assert indexer.dtype == bool
357365
assert indexer.shape[0] == data.shape[axis]
358366

359-
# useful variables
367+
# Useful variables.
360368
old_chunks = data.chunks
361369
axis_old_chunks = old_chunks[axis]
362370

363-
# load the indexer temporarily for chunk size computations
371+
# Load the indexer temporarily for chunk size computations.
364372
if indexer_computed is None:
365373
indexer_computed = indexer.compute()
366374

367-
# ensure indexer and data are chunked in the same way
375+
# Ensure indexer and data are chunked in the same way.
368376
indexer = indexer.rechunk((axis_old_chunks,))
369377

370-
# apply the indexing operation
378+
# Apply the indexing operation.
371379
v = da.compress(indexer, data, axis=axis)
372380

373-
# need to compute chunks sizes in order to know dimension sizes;
381+
# Need to compute chunk sizes in order to know dimension sizes;
374382
# would normally do v.compute_chunk_sizes() but that is slow for
375-
# multidimensional arrays, so hack something more efficient
376-
383+
# multidimensional arrays, so hack something more efficient.
377384
axis_new_chunks_list = []
378385
slice_start = 0
386+
need_rechunk = False
379387
for old_chunk_size in axis_old_chunks:
380388
slice_stop = slice_start + old_chunk_size
381-
new_chunk_size = np.sum(indexer_computed[slice_start:slice_stop])
389+
new_chunk_size = int(np.sum(indexer_computed[slice_start:slice_stop]))
390+
if new_chunk_size == 0:
391+
need_rechunk = True
382392
axis_new_chunks_list.append(new_chunk_size)
383393
slice_start = slice_stop
384394
axis_new_chunks = tuple(axis_new_chunks_list)
@@ -387,6 +397,23 @@ def da_compress(
387397
)
388398
v._chunks = new_chunks
389399

400+
# Deal with empty chunks, they break reductions.
401+
# Possibly related to https://github.com/dask/dask/issues/10327
402+
# and https://github.com/dask/dask/issues/2794
403+
if need_rechunk:
404+
axis_new_chunks_nonzero = tuple([x for x in axis_new_chunks if x > 0])
405+
# Edge case, all chunks empty:
406+
if len(axis_new_chunks_nonzero) == 0:
407+
# Not much we can do about this, no data.
408+
axis_new_chunks_nonzero = (0,)
409+
new_chunks_nonzero = tuple(
410+
[
411+
axis_new_chunks_nonzero if i == axis else c
412+
for i, c in enumerate(new_chunks)
413+
]
414+
)
415+
v = v.rechunk(new_chunks_nonzero)
416+
390417
return v
391418

392419

@@ -1461,6 +1488,64 @@ def apply_allele_mapping(x, mapping, max_allele):
14611488
return out
14621489

14631490

1491+
def dask_apply_allele_mapping(v, mapping, max_allele):
    """Lazily apply an allele mapping to a 2D dask array, block by block.

    Both ``v`` and ``mapping`` must be 2D dask arrays with the same length
    along the first axis. Each is rechunked to a single chunk along the
    second axis, and ``mapping`` is aligned with the first-axis chunks of
    ``v``, so that every block handed to ``apply_allele_mapping`` holds
    complete rows of both arrays. The result is a dask array with
    ``max_allele + 1`` columns and the dtype of ``v``.
    """
    # Type checks first, then dimensionality, then matching lengths.
    for arr in (v, mapping):
        assert isinstance(arr, da.Array)
    for arr in (v, mapping):
        assert arr.ndim == 2
    assert v.shape[0] == mapping.shape[0]

    # Collapse the second axis into one chunk, keeping first-axis chunking.
    v = v.rechunk((v.chunks[0], -1))
    row_chunks = v.chunks[0]

    # Align the mapping with the data so blocks pair up one-to-one.
    mapping = mapping.rechunk((row_chunks, -1))

    def _map_block(xb, mb):
        return apply_allele_mapping(xb, mb, max_allele=max_allele)

    return da.map_blocks(
        _map_block,
        v,
        mapping,
        dtype=v.dtype,
        chunks=(row_chunks, [max_allele + 1]),
    )
1507+
1508+
1509+
def genotype_array_map_alleles(gt, mapping):
    """Transform a block of genotype calls via an allele mapping.

    ``gt`` is a 3D numpy array whose last axis has length 2 and whose
    second axis is non-empty. ``mapping`` is a 3D numpy array with the same
    first-axis length as ``gt``; only ``mapping[:, 0, :]`` is used.

    N.B., scikit-allel does not handle empty blocks well, so an empty
    block is returned unchanged without calling into scikit-allel.
    """
    assert isinstance(gt, np.ndarray)
    assert isinstance(mapping, np.ndarray)
    assert gt.ndim == 3
    assert mapping.ndim == 3
    n_variants = gt.shape[0]
    assert n_variants == mapping.shape[0]
    assert gt.shape[1] > 0
    assert gt.shape[2] == 2

    if gt.size == 0:
        # Block is empty so no alleles need to be mapped.
        assert n_variants == 0
        return gt

    # Block is not empty, can pass through to GenotypeArray.
    assert n_variants > 0
    allele_map = mapping[:, 0, :]
    remapped = allel.GenotypeArray(gt).map_alleles(allele_map)
    return remapped.values
1530+
1531+
1532+
def dask_genotype_array_map_alleles(gt, mapping):
    """Lazily transform genotype calls via an allele mapping.

    ``gt`` must be a 3D dask array and ``mapping`` a 2D dask array with the
    same length along the first axis. The mapping is rechunked to match the
    first-axis chunks of ``gt`` and given a length-1 middle axis, then
    ``genotype_array_map_alleles`` is applied to each pair of blocks. The
    result keeps the chunks and dtype of ``gt``.
    """
    assert isinstance(gt, da.Array)
    assert isinstance(mapping, da.Array)
    assert gt.ndim == 3
    assert mapping.ndim == 2
    assert gt.shape[0] == mapping.shape[0]

    # Align the mapping with the genotype chunks along the first axis, with
    # a single chunk along the second axis.
    mapping_aligned = mapping.rechunk((gt.chunks[0], -1))

    # Insert a length-1 middle axis so the mapping is 3D like the genotypes.
    mapping_3d = mapping_aligned[:, None, :]

    return da.map_blocks(
        genotype_array_map_alleles,
        gt,
        mapping_3d,
        chunks=gt.chunks,
        dtype=gt.dtype,
    )
1547+
1548+
14641549
def pandas_apply(f, df, columns):
14651550
"""Optimised alternative to pandas apply."""
14661551
df = df.reset_index(drop=True)

0 commit comments

Comments
 (0)