Skip to content

Commit a834cc0

Browse files
Author: khushthecoder (committed)
Fix #1308: Scope dask.config.set() to specific operations instead of module import
Move the `split_native_chunks` config override from module level in ag3.py to context managers within the specific methods that require it. This prevents importing malariagen_data from silently modifying global dask configuration, which could degrade performance for unrelated dask workloads in the same Python session.

Affected operations:
- util._da_compress(): wraps the da.compress() call
- snp_data.snp_genotypes(): wraps the da.compress() and da.take() calls
- snp_data._locate_site_class(): wraps the da.take() call
1 parent 9c1bb56 commit a834cc0

3 files changed

Lines changed: 24 additions & 25 deletions

File tree

malariagen_data/ag3.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,10 @@
11
import sys
22

3-
import dask
43
import pandas as pd # type: ignore
54
import plotly.express as px # type: ignore
65
import malariagen_data
76
from .anopheles import AnophelesDataResource
87

9-
# silence dask performance warnings
10-
dask.config.set(**{"array.slicing.split_native_chunks": False}) # type: ignore
11-
128
MAJOR_VERSION_NUMBER = 3
139
MAJOR_VERSION_PATH = "v3"
1410
CONFIG_PATH = "v3-config.json"

malariagen_data/anoph/snp_data.py

Lines changed: 18 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
import allel # type: ignore
77
import bokeh
8+
import dask
89
import dask.array as da
910
import numpy as np
1011
import pandas as pd
@@ -72,13 +73,13 @@ def __init__(
7273
self._cache_snp_genotypes: Dict[
7374
base_params.sample_set, zarr.hierarchy.Group
7475
] = dict()
75-
self._cache_site_filters: Dict[
76-
base_params.site_mask, zarr.hierarchy.Group
77-
] = dict()
76+
self._cache_site_filters: Dict[base_params.site_mask, zarr.hierarchy.Group] = (
77+
dict()
78+
)
7879
self._cache_site_annotations: Optional[zarr.hierarchy.Group] = None
79-
self._cache_locate_site_class: OrderedDict[
80-
Tuple[Any, ...], np.ndarray
81-
] = OrderedDict()
80+
self._cache_locate_site_class: OrderedDict[Tuple[Any, ...], np.ndarray] = (
81+
OrderedDict()
82+
)
8283

8384
# Create the SNP-calls cache as a per-instance lru_cache wrapping the
8485
# bound method. Storing it on the instance (rather than using a
@@ -253,8 +254,7 @@ def _site_filters_for_contig(
253254
else:
254255
if contig not in self.contigs:
255256
raise ValueError(
256-
f"Contig {contig!r} not found. "
257-
f"Available contigs: {self.contigs}"
257+
f"Contig {contig!r} not found. Available contigs: {self.contigs}"
258258
)
259259
root = self.open_site_filters(mask=mask)
260260
z = root[f"{contig}/variants/{field}"]
@@ -359,8 +359,7 @@ def _snp_sites_for_contig(
359359
else:
360360
if contig not in self.contigs:
361361
raise ValueError(
362-
f"Contig {contig!r} not found. "
363-
f"Available contigs: {self.contigs}"
362+
f"Contig {contig!r} not found. Available contigs: {self.contigs}"
364363
)
365364
root = self.open_snp_sites()
366365
z = root[f"{contig}/variants/{field}"]
@@ -488,8 +487,7 @@ def _snp_genotypes_for_contig(
488487
else:
489488
if contig not in self.contigs:
490489
raise ValueError(
491-
f"Contig {contig!r} not found. "
492-
f"Available contigs: {self.contigs}"
490+
f"Contig {contig!r} not found. Available contigs: {self.contigs}"
493491
)
494492
root = self.open_snp_genotypes(sample_set=sample_set)
495493
z = root[f"{contig}/calldata/{field}"]
@@ -612,12 +610,14 @@ def snp_genotypes(
612610
)
613611

614612
# Filter the Dask array using the boolean array.
615-
d = da.compress(loc_samples, d, axis=1)
613+
with dask.config.set(**{"array.slicing.split_native_chunks": False}):
614+
d = da.compress(loc_samples, d, axis=1)
616615

617616
# Apply the sample_indices, if there are any.
618617
# Note: this might need to apply to the result of an internal sample_query, e.g. `is_surveillance == True`.
619618
if sample_indices is not None:
620-
d = da.take(d, sample_indices, axis=1)
619+
with dask.config.set(**{"array.slicing.split_native_chunks": False}):
620+
d = da.take(d, sample_indices, axis=1)
621621

622622
return d
623623

@@ -648,8 +648,7 @@ def _snp_variants_for_contig(
648648
else:
649649
if contig not in self.contigs:
650650
raise ValueError(
651-
f"Contig {contig!r} not found. "
652-
f"Available contigs: {self.contigs}"
651+
f"Contig {contig!r} not found. Available contigs: {self.contigs}"
653652
)
654653
coords = dict()
655654
data_vars = dict()
@@ -1021,7 +1020,8 @@ def _locate_site_class(
10211020
chunks=chunks,
10221021
)
10231022
idx = (pos - 1).compute()
1024-
loc_ann = da.take(loc_ann, idx, axis=0)
1023+
with dask.config.set(**{"array.slicing.split_native_chunks": False}):
1024+
loc_ann = da.take(loc_ann, idx, axis=0)
10251025

10261026
# Compute site selection.
10271027
with self._dask_progress(desc=f"Locate {site_class} sites"):
@@ -1066,8 +1066,7 @@ def _snp_calls_for_contig(
10661066
else:
10671067
if contig not in self.contigs:
10681068
raise ValueError(
1069-
f"Contig {contig!r} not found. "
1070-
f"Available contigs: {self.contigs}"
1069+
f"Contig {contig!r} not found. Available contigs: {self.contigs}"
10711070
)
10721071

10731072
coords = dict()

malariagen_data/util.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
colab = None
2121

2222
import allel # type: ignore
23+
import dask
2324
import dask.array as da
2425
from dask.utils import parse_bytes
2526
import numba # type: ignore
@@ -407,8 +408,11 @@ def _da_compress(
407408
else:
408409
indexer = da.from_array(indexer, chunks=(axis_old_chunks,))
409410

410-
# Apply the indexing operation.
411-
v = da.compress(indexer, data, axis=axis)
411+
# Apply the indexing operation, suppressing the dask performance warning
412+
# about split_native_chunks. This config is scoped here rather than at
413+
# module level to avoid silently modifying global dask configuration.
414+
with dask.config.set(**{"array.slicing.split_native_chunks": False}):
415+
v = da.compress(indexer, data, axis=axis)
412416

413417
# Need to compute chunks sizes in order to know dimension sizes;
414418
# would normally do v.compute_chunk_sizes() but that is slow for

0 commit comments

Comments (0)