Skip to content

Commit b4a3cc9

Browse files
authored
Further work on scalability for large biallelic genotype data computations (#626)
* back out usage of zarr * revert to native chunks for now * revert to native chunks for now
1 parent 9f13fb0 commit b4a3cc9

11 files changed

Lines changed: 57 additions & 61 deletions

File tree

malariagen_data/ag3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .anopheles import AnophelesDataResource
99

1010
# silence dask performance warnings
11-
dask.config.set(**{"array.slicing.split_large_chunks": False}) # type: ignore
11+
# NOTE: the dask option is ``array.slicing.split_large_chunks``. It is
# unrelated to this package's ``native_chunks`` default and must not be
# renamed with it -- dask accepts unknown config keys silently, so a
# misspelt key would leave the performance warning enabled.
dask.config.set(**{"array.slicing.split_large_chunks": False})  # type: ignore
1212

1313
MAJOR_VERSION_NUMBER = 3
1414
MAJOR_VERSION_PATH = "v3"

malariagen_data/anoph/base_params.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,6 @@ def validate_sample_selection_params(
248248
# amounts of data.
249249
native_chunks: chunks = "native"
250250

251-
# Alternative default chunk size, suitable for functions which need to
252-
# scan a large amount of data.
253-
large_chunks: chunks = "300MiB"
254-
255251
gff_attributes: TypeAlias = Annotated[
256252
Optional[Union[Sequence[str], str]],
257253
"""

malariagen_data/anoph/fst.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def fst_gwss(
115115
] = fst_params.max_cohort_size_default,
116116
random_seed: base_params.random_seed = 42,
117117
inline_array: base_params.inline_array = base_params.inline_array_default,
118-
chunks: base_params.chunks = base_params.large_chunks,
118+
chunks: base_params.chunks = base_params.native_chunks,
119119
clip_min: fst_params.clip_min = 0.0,
120120
) -> Tuple[np.ndarray, np.ndarray]:
121121
# Change this name if you ever change the behaviour of this function, to

malariagen_data/anoph/g123.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def g123_gwss(
161161
] = g123_params.max_cohort_size_default,
162162
random_seed: base_params.random_seed = 42,
163163
inline_array: base_params.inline_array = base_params.inline_array_default,
164-
chunks: base_params.chunks = base_params.large_chunks,
164+
chunks: base_params.chunks = base_params.native_chunks,
165165
) -> Tuple[np.ndarray, np.ndarray]:
166166
# Change this name if you ever change the behaviour of this function, to
167167
# invalidate any previously cached data.
@@ -264,7 +264,7 @@ def g123_calibration(
264264
window_sizes: g123_params.window_sizes = g123_params.window_sizes_default,
265265
random_seed: base_params.random_seed = 42,
266266
inline_array: base_params.inline_array = base_params.inline_array_default,
267-
chunks: base_params.chunks = base_params.large_chunks,
267+
chunks: base_params.chunks = base_params.native_chunks,
268268
) -> Mapping[str, np.ndarray]:
269269
# Change this name if you ever change the behaviour of this function, to
270270
# invalidate any previously cached data.
@@ -323,7 +323,7 @@ def plot_g123_gwss_track(
323323
x_range: Optional[gplt_params.x_range] = None,
324324
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
325325
inline_array: base_params.inline_array = base_params.inline_array_default,
326-
chunks: base_params.chunks = base_params.large_chunks,
326+
chunks: base_params.chunks = base_params.native_chunks,
327327
) -> gplt_params.figure:
328328
# compute G123
329329
x, g123 = self.g123_gwss(
@@ -424,7 +424,7 @@ def plot_g123_gwss(
424424
show: gplt_params.show = True,
425425
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
426426
inline_array: base_params.inline_array = base_params.inline_array_default,
427-
chunks: base_params.chunks = base_params.large_chunks,
427+
chunks: base_params.chunks = base_params.native_chunks,
428428
) -> gplt_params.figure:
429429
# gwss track
430430
fig1 = self.plot_g123_gwss_track(
@@ -497,7 +497,7 @@ def plot_g123_calibration(
497497
title: Optional[gplt_params.title] = None,
498498
show: gplt_params.show = True,
499499
inline_array: base_params.inline_array = base_params.inline_array_default,
500-
chunks: base_params.chunks = base_params.large_chunks,
500+
chunks: base_params.chunks = base_params.native_chunks,
501501
) -> gplt_params.figure:
502502
# get g123 values
503503
calibration_runs = self.g123_calibration(

malariagen_data/anoph/h12.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def h12_calibration(
8585
] = h12_params.max_cohort_size_default,
8686
window_sizes: h12_params.window_sizes = h12_params.window_sizes_default,
8787
random_seed: base_params.random_seed = 42,
88-
chunks: base_params.chunks = base_params.large_chunks,
88+
chunks: base_params.chunks = base_params.native_chunks,
8989
inline_array: base_params.inline_array = base_params.inline_array_default,
9090
) -> Mapping[str, np.ndarray]:
9191
# Change this name if you ever change the behaviour of this function, to
@@ -143,7 +143,7 @@ def plot_h12_calibration(
143143
random_seed: base_params.random_seed = 42,
144144
title: Optional[str] = None,
145145
show: bool = True,
146-
chunks: base_params.chunks = base_params.large_chunks,
146+
chunks: base_params.chunks = base_params.native_chunks,
147147
inline_array: base_params.inline_array = base_params.inline_array_default,
148148
) -> gplt_params.figure:
149149
# Get H12 values.
@@ -286,7 +286,7 @@ def h12_gwss(
286286
base_params.max_cohort_size
287287
] = h12_params.max_cohort_size_default,
288288
random_seed: base_params.random_seed = 42,
289-
chunks: base_params.chunks = base_params.large_chunks,
289+
chunks: base_params.chunks = base_params.native_chunks,
290290
inline_array: base_params.inline_array = base_params.inline_array_default,
291291
) -> Tuple[np.ndarray, np.ndarray]:
292292
# Change this name if you ever change the behaviour of this function, to
@@ -346,7 +346,7 @@ def plot_h12_gwss_track(
346346
show: gplt_params.show = True,
347347
x_range: Optional[gplt_params.x_range] = None,
348348
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
349-
chunks: base_params.chunks = base_params.large_chunks,
349+
chunks: base_params.chunks = base_params.native_chunks,
350350
inline_array: base_params.inline_array = base_params.inline_array_default,
351351
) -> gplt_params.figure:
352352
# Compute H12.
@@ -447,7 +447,7 @@ def plot_h12_gwss(
447447
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
448448
show: gplt_params.show = True,
449449
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
450-
chunks: base_params.chunks = base_params.large_chunks,
450+
chunks: base_params.chunks = base_params.native_chunks,
451451
inline_array: base_params.inline_array = base_params.inline_array_default,
452452
) -> gplt_params.figure:
453453
# Plot GWSS track.

malariagen_data/anoph/h1x.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def h1x_gwss(
112112
base_params.max_cohort_size
113113
] = h12_params.max_cohort_size_default,
114114
random_seed: base_params.random_seed = 42,
115-
chunks: base_params.chunks = base_params.large_chunks,
115+
chunks: base_params.chunks = base_params.native_chunks,
116116
inline_array: base_params.inline_array = base_params.inline_array_default,
117117
) -> Tuple[np.ndarray, np.ndarray]:
118118
# Change this name if you ever change the behaviour of this function, to
@@ -177,7 +177,7 @@ def plot_h1x_gwss_track(
177177
show: gplt_params.show = True,
178178
x_range: Optional[gplt_params.x_range] = None,
179179
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
180-
chunks: base_params.chunks = base_params.large_chunks,
180+
chunks: base_params.chunks = base_params.native_chunks,
181181
inline_array: base_params.inline_array = base_params.inline_array_default,
182182
) -> gplt_params.figure:
183183
# Compute H1X.
@@ -283,7 +283,7 @@ def plot_h1x_gwss(
283283
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
284284
show: gplt_params.show = True,
285285
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
286-
chunks: base_params.chunks = base_params.large_chunks,
286+
chunks: base_params.chunks = base_params.native_chunks,
287287
inline_array: base_params.inline_array = base_params.inline_array_default,
288288
) -> gplt_params.figure:
289289
# Plot GWSS track.

malariagen_data/anoph/pca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def pca(
7575
fit_exclude_samples: Optional[base_params.samples] = None,
7676
random_seed: base_params.random_seed = 42,
7777
inline_array: base_params.inline_array = base_params.inline_array_default,
78-
chunks: base_params.chunks = base_params.large_chunks,
78+
chunks: base_params.chunks = base_params.native_chunks,
7979
) -> Tuple[pca_params.df_pca, pca_params.evr]:
8080
# Change this name if you ever change the behaviour of this function, to
8181
# invalidate any previously cached data.

malariagen_data/anoph/snp_data.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
DIM_VARIANT,
1818
CacheMiss,
1919
Region,
20+
apply_allele_mapping,
2021
check_types,
2122
da_compress,
2223
da_concat,
@@ -1629,30 +1630,20 @@ def biallelic_snp_calls(
16291630
variant_position = ds_bi["variant_position"].data
16301631
coords["variant_position"] = ("variants",), variant_position
16311632

1632-
# Prepare allele mapping for dask computations.
1633-
allele_mapping_zarr = zarr.array(allele_mapping)
1634-
allele_mapping_dask = da_from_zarr(
1635-
allele_mapping_zarr, chunks="native", inline_array=True
1636-
)
1637-
16381633
# Store alleles, transformed.
16391634
variant_allele_dask = ds_bi["variant_allele"].data
16401635
variant_allele_out = dask_apply_allele_mapping(
1641-
variant_allele_dask, allele_mapping_dask, max_allele=1
1636+
variant_allele_dask, allele_mapping, max_allele=1
16421637
)
16431638
data_vars["variant_allele"] = ("variants", "alleles"), variant_allele_out
16441639

16451640
# Store allele counts, transformed.
1646-
ac_bi_zarr = zarr.array(ac_bi)
1647-
ac_bi_dask = da_from_zarr(ac_bi_zarr, chunks="native", inline_array=True)
1648-
ac_out = dask_apply_allele_mapping(
1649-
ac_bi_dask, allele_mapping_dask, max_allele=1
1650-
)
1641+
ac_out = apply_allele_mapping(ac_bi, allele_mapping, max_allele=1)
16511642
data_vars["variant_allele_count"] = ("variants", "alleles"), ac_out
16521643

16531644
# Store genotype calls, transformed.
16541645
gt_dask = ds_bi["call_genotype"].data
1655-
gt_out = dask_genotype_array_map_alleles(gt_dask, allele_mapping_dask)
1646+
gt_out = dask_genotype_array_map_alleles(gt_dask, allele_mapping)
16561647
data_vars["call_genotype"] = (
16571648
(
16581649
"variants",
@@ -1667,9 +1658,8 @@ def biallelic_snp_calls(
16671658

16681659
# Apply conditions.
16691660
if max_missing_an is not None or min_minor_ac is not None:
1670-
ac_out_computed = ac_out.compute()
16711661
loc_out = np.ones(ds_out.sizes["variants"], dtype=bool)
1672-
an = ac_out_computed.sum(axis=1)
1662+
an = ac_out.sum(axis=1)
16731663

16741664
# Apply missingness condition.
16751665
if max_missing_an is not None:
@@ -1683,7 +1673,7 @@ def biallelic_snp_calls(
16831673

16841674
# Apply minor allele count condition.
16851675
if min_minor_ac is not None:
1686-
ac_minor = ac_out_computed.min(axis=1)
1676+
ac_minor = ac_out.min(axis=1)
16871677
if isinstance(min_minor_ac, float):
16881678
ac_minor_frac = ac_minor / an
16891679
loc_minor = ac_minor_frac >= min_minor_ac

malariagen_data/anopheles.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,7 +1622,7 @@ def cohort_diversity_stats(
16221622
random_seed: base_params.random_seed = 42,
16231623
n_jack: base_params.n_jack = 200,
16241624
confidence_level: base_params.confidence_level = 0.95,
1625-
chunks: base_params.chunks = base_params.large_chunks,
1625+
chunks: base_params.chunks = base_params.native_chunks,
16261626
inline_array: base_params.inline_array = base_params.inline_array_default,
16271627
) -> pd.Series:
16281628
debug = self._log.debug
@@ -1728,7 +1728,7 @@ def diversity_stats(
17281728
random_seed: base_params.random_seed = 42,
17291729
n_jack: base_params.n_jack = 200,
17301730
confidence_level: base_params.confidence_level = 0.95,
1731-
chunks: base_params.chunks = base_params.large_chunks,
1731+
chunks: base_params.chunks = base_params.native_chunks,
17321732
inline_array: base_params.inline_array = base_params.inline_array_default,
17331733
) -> pd.DataFrame:
17341734
# Normalise cohorts parameter.
@@ -1933,7 +1933,7 @@ def ihs_gwss(
19331933
base_params.max_cohort_size
19341934
] = ihs_params.max_cohort_size_default,
19351935
random_seed: base_params.random_seed = 42,
1936-
chunks: base_params.chunks = base_params.large_chunks,
1936+
chunks: base_params.chunks = base_params.native_chunks,
19371937
inline_array: base_params.inline_array = base_params.inline_array_default,
19381938
) -> Tuple[np.ndarray, np.ndarray]:
19391939
# change this name if you ever change the behaviour of this function, to
@@ -2110,7 +2110,7 @@ def plot_ihs_gwss_track(
21102110
show: gplt_params.show = True,
21112111
x_range: Optional[gplt_params.x_range] = None,
21122112
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
2113-
chunks: base_params.chunks = base_params.large_chunks,
2113+
chunks: base_params.chunks = base_params.native_chunks,
21142114
inline_array: base_params.inline_array = base_params.inline_array_default,
21152115
) -> gplt_params.figure:
21162116
# compute ihs
@@ -2251,7 +2251,7 @@ def plot_xpehh_gwss(
22512251
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
22522252
show: gplt_params.show = True,
22532253
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
2254-
chunks: base_params.chunks = base_params.large_chunks,
2254+
chunks: base_params.chunks = base_params.native_chunks,
22552255
inline_array: base_params.inline_array = base_params.inline_array_default,
22562256
) -> gplt_params.figure:
22572257
# gwss track
@@ -2350,7 +2350,7 @@ def plot_ihs_gwss(
23502350
genes_height: gplt_params.genes_height = gplt_params.genes_height_default,
23512351
show: gplt_params.show = True,
23522352
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
2353-
chunks: base_params.chunks = base_params.large_chunks,
2353+
chunks: base_params.chunks = base_params.native_chunks,
23542354
inline_array: base_params.inline_array = base_params.inline_array_default,
23552355
) -> gplt_params.figure:
23562356
# gwss track
@@ -2445,7 +2445,7 @@ def xpehh_gwss(
24452445
base_params.max_cohort_size
24462446
] = xpehh_params.max_cohort_size_default,
24472447
random_seed: base_params.random_seed = 42,
2448-
chunks: base_params.chunks = base_params.large_chunks,
2448+
chunks: base_params.chunks = base_params.native_chunks,
24492449
inline_array: base_params.inline_array = base_params.inline_array_default,
24502450
) -> Tuple[np.ndarray, np.ndarray]:
24512451
# change this name if you ever change the behaviour of this function, to
@@ -2624,7 +2624,7 @@ def plot_xpehh_gwss_track(
26242624
show: gplt_params.show = True,
26252625
x_range: Optional[gplt_params.x_range] = None,
26262626
output_backend: gplt_params.output_backend = gplt_params.output_backend_default,
2627-
chunks: base_params.chunks = base_params.large_chunks,
2627+
chunks: base_params.chunks = base_params.native_chunks,
26282628
inline_array: base_params.inline_array = base_params.inline_array_default,
26292629
) -> gplt_params.figure:
26302630
# compute xpehh
@@ -3269,7 +3269,7 @@ def plot_njt(
32693269
max_cohort_size: Optional[base_params.max_cohort_size] = None,
32703270
random_seed: base_params.random_seed = 42,
32713271
inline_array: base_params.inline_array = base_params.inline_array_default,
3272-
chunks: base_params.chunks = base_params.large_chunks,
3272+
chunks: base_params.chunks = base_params.native_chunks,
32733273
) -> plotly_params.figure:
32743274
from biotite.sequence.phylo import neighbor_joining # type: ignore
32753275
from scipy.spatial.distance import squareform # type: ignore

malariagen_data/util.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,10 @@ def da_from_zarr(
239239
#
240240
# N.B., only resize chunks in arrays with more than one dimension,
241241
# because resizing the one-dimensional arrays according to the same
242-
# size generally leads to poor performance with our datasets.
242+
# size may lead to poor performance with our datasets.
243243
#
244244
# Also, resize along the first dimension only. Again, this is something
245-
# that generally works well for our datasets.
245+
# that may work well for our datasets.
246246
#
247247
# Note that dask also supports this kind of argument, and so we could
248248
# just pass this through. However, some experiments have found this
@@ -313,8 +313,6 @@ def dask_compress_dataset(ds, indexer, dim):
313313
else:
314314
assert isinstance(indexer, np.ndarray)
315315
indexer_computed = indexer
316-
indexer_zarr = zarr.array(indexer_computed)
317-
indexer = da_from_zarr(indexer_zarr, chunks="native", inline_array=True)
318316

319317
coords = dict()
320318
for k in ds.coords:
@@ -344,15 +342,22 @@ def _dask_compress_dataarray(a, indexer, indexer_computed, dim):
344342

345343
else:
346344
# apply the indexing operation
347-
v = da_compress(
348-
indexer=indexer, data=a.data, axis=axis, indexer_computed=indexer_computed
349-
)
345+
data = a.data
346+
if isinstance(data, da.Array):
347+
v = da_compress(
348+
indexer=indexer,
349+
data=a.data,
350+
axis=axis,
351+
indexer_computed=indexer_computed,
352+
)
353+
else:
354+
v = np.compress(indexer_computed, data, axis=axis)
350355

351356
return v
352357

353358

354359
def da_compress(
355-
indexer: da.Array,
360+
indexer: da.Array | np.ndarray,
356361
data: da.Array,
357362
axis: int,
358363
indexer_computed: Optional[np.ndarray] = None,
@@ -373,7 +378,10 @@ def da_compress(
373378
indexer_computed = indexer.compute()
374379

375380
# Ensure indexer and data are chunked in the same way.
376-
indexer = indexer.rechunk((axis_old_chunks,))
381+
if isinstance(indexer, da.Array):
382+
indexer = indexer.rechunk((axis_old_chunks,))
383+
else:
384+
indexer = da.from_array(indexer, chunks=(axis_old_chunks,))
377385

378386
# Apply the indexing operation.
379387
v = da.compress(indexer, data, axis=axis)
@@ -1490,12 +1498,12 @@ def apply_allele_mapping(x, mapping, max_allele):
14901498

14911499
def dask_apply_allele_mapping(v, mapping, max_allele):
14921500
assert isinstance(v, da.Array)
1493-
assert isinstance(mapping, da.Array)
1501+
assert isinstance(mapping, np.ndarray)
14941502
assert v.ndim == 2
14951503
assert mapping.ndim == 2
14961504
assert v.shape[0] == mapping.shape[0]
14971505
v = v.rechunk((v.chunks[0], -1))
1498-
mapping = mapping.rechunk((v.chunks[0], -1))
1506+
mapping = da.from_array(mapping, chunks=(v.chunks[0], -1))
14991507
out = da.map_blocks(
15001508
lambda xb, mb: apply_allele_mapping(xb, mb, max_allele=max_allele),
15011509
v,
@@ -1531,11 +1539,11 @@ def genotype_array_map_alleles(gt, mapping):
15311539

15321540
def dask_genotype_array_map_alleles(gt, mapping):
15331541
assert isinstance(gt, da.Array)
1534-
assert isinstance(mapping, da.Array)
1542+
assert isinstance(mapping, np.ndarray)
15351543
assert gt.ndim == 3
15361544
assert mapping.ndim == 2
15371545
assert gt.shape[0] == mapping.shape[0]
1538-
mapping = mapping.rechunk((gt.chunks[0], -1))
1546+
mapping = da.from_array(mapping, chunks=(gt.chunks[0], -1))
15391547
gt_out = da.map_blocks(
15401548
genotype_array_map_alleles,
15411549
gt,

0 commit comments

Comments
 (0)