Skip to content

Commit dc16a6b

Browse files
authored
Merge pull request #628 from malariagen/GH410_add_sample_query_options
Add sample_query_options
2 parents 9f1db09 + 67d9020 commit dc16a6b

19 files changed

Lines changed: 785 additions & 37 deletions

malariagen_data/anoph/aim_data.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,7 @@ def aim_calls(
121121
aims: aim_params.aims,
122122
sample_sets: Optional[base_params.sample_sets] = None,
123123
sample_query: Optional[base_params.sample_query] = None,
124+
sample_query_options: Optional[base_params.sample_query_options] = None,
124125
) -> xr.Dataset:
125126
self._require_aim_analysis()
126127

@@ -144,7 +145,8 @@ def aim_calls(
144145
# Handle sample query.
145146
if sample_query is not None:
146147
df_samples = self.sample_metadata(sample_sets=sample_sets_prepped)
147-
loc_samples = df_samples.eval(sample_query).values
148+
sample_query_options = sample_query_options or {}
149+
loc_samples = df_samples.eval(sample_query, **sample_query_options).values
148150
if np.count_nonzero(loc_samples) == 0:
149151
raise ValueError(f"No samples found for query {sample_query!r}")
150152
ds = ds.isel(samples=loc_samples)
@@ -170,6 +172,7 @@ def plot_aim_heatmap(
170172
aims: aim_params.aims,
171173
sample_sets: Optional[base_params.sample_sets] = None,
172174
sample_query: Optional[base_params.sample_query] = None,
175+
sample_query_options: Optional[base_params.sample_query_options] = None,
173176
sort: bool = True,
174177
row_height: int = 4,
175178
xgap: float = 0,
@@ -183,6 +186,7 @@ def plot_aim_heatmap(
183186
aims=aims,
184187
sample_sets=sample_sets,
185188
sample_query=sample_query,
189+
sample_query_options=sample_query_options,
186190
).compute()
187191
samples = ds["sample_id"].values
188192
variant_contig = ds["variant_contig"].values

malariagen_data/anoph/base_params.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@
7373
""",
7474
]
7575

76+
sample_query_options: TypeAlias = Annotated[
77+
dict,
78+
"""
79+
A dictionary of arguments that will be passed through to pandas query() or
80+
eval(), e.g. parser, engine, local_dict, global_dict, resolvers.
81+
""",
82+
]
83+
7684
sample_indices: TypeAlias = Annotated[
7785
List[int],
7886
"""

malariagen_data/anoph/cnv_data.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def cnv_hmm(
177177
region: base_params.regions,
178178
sample_sets: Optional[base_params.sample_sets] = None,
179179
sample_query: Optional[base_params.sample_query] = None,
180+
sample_query_options: Optional[base_params.sample_query_options] = None,
180181
max_coverage_variance: cnv_params.max_coverage_variance = cnv_params.max_coverage_variance_default,
181182
inline_array: base_params.inline_array = base_params.inline_array_default,
182183
chunks: base_params.chunks = base_params.native_chunks,
@@ -241,7 +242,10 @@ def cnv_hmm(
241242
)
242243

243244
debug("apply the query")
244-
loc_query_samples = df_samples_cnv.eval(sample_query).values
245+
sample_query_options = sample_query_options or {}
246+
loc_query_samples = df_samples_cnv.eval(
247+
sample_query, **sample_query_options
248+
).values
245249
if np.count_nonzero(loc_query_samples) == 0:
246250
raise ValueError(f"No samples found for query {sample_query!r}")
247251

@@ -536,6 +540,7 @@ def cnv_discordant_read_calls(
536540
contig: base_params.contigs,
537541
sample_sets: Optional[base_params.sample_sets] = None,
538542
sample_query: Optional[base_params.sample_query] = None,
543+
sample_query_options: Optional[base_params.sample_query_options] = None,
539544
inline_array: base_params.inline_array = base_params.inline_array_default,
540545
chunks: base_params.chunks = base_params.native_chunks,
541546
) -> xr.Dataset:
@@ -588,7 +593,10 @@ def cnv_discordant_read_calls(
588593
)
589594

590595
debug("apply the query")
591-
loc_query_samples = df_samples_cnv.eval(sample_query).values
596+
sample_query_options = sample_query_options or {}
597+
loc_query_samples = df_samples_cnv.eval(
598+
sample_query, **sample_query_options
599+
).values
592600
if np.count_nonzero(loc_query_samples) == 0:
593601
raise ValueError(f"No samples found for query {sample_query!r}")
594602

@@ -801,6 +809,7 @@ def plot_cnv_hmm_heatmap_track(
801809
region: base_params.region,
802810
sample_sets: Optional[base_params.sample_sets] = None,
803811
sample_query: Optional[base_params.sample_query] = None,
812+
sample_query_options: Optional[base_params.sample_query_options] = None,
804813
max_coverage_variance: cnv_params.max_coverage_variance = cnv_params.max_coverage_variance_default,
805814
sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default,
806815
width: gplt_params.width = gplt_params.width_default,
@@ -823,6 +832,7 @@ def plot_cnv_hmm_heatmap_track(
823832
region=region_prepped,
824833
sample_sets=sample_sets,
825834
sample_query=sample_query,
835+
sample_query_options=sample_query_options,
826836
max_coverage_variance=max_coverage_variance,
827837
)
828838

@@ -942,6 +952,7 @@ def plot_cnv_hmm_heatmap(
942952
region: base_params.region,
943953
sample_sets: Optional[base_params.sample_sets] = None,
944954
sample_query: Optional[base_params.sample_query] = None,
955+
sample_query_options: Optional[base_params.sample_query_options] = None,
945956
max_coverage_variance: cnv_params.max_coverage_variance = cnv_params.max_coverage_variance_default,
946957
sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default,
947958
width: gplt_params.width = gplt_params.width_default,
@@ -960,6 +971,7 @@ def plot_cnv_hmm_heatmap(
960971
region=region,
961972
sample_sets=sample_sets,
962973
sample_query=sample_query,
974+
sample_query_options=sample_query_options,
963975
max_coverage_variance=max_coverage_variance,
964976
sizing_mode=sizing_mode,
965977
width=width,

malariagen_data/anoph/dipclust.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def plot_diplotype_clustering(
5050
site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
5151
sample_sets: Optional[base_params.sample_sets] = None,
5252
sample_query: Optional[base_params.sample_query] = None,
53+
sample_query_options: Optional[base_params.sample_query_options] = None,
5354
cohort_size: Optional[base_params.cohort_size] = None,
5455
random_seed: base_params.random_seed = 42,
5556
color: plotly_params.color = None,
@@ -89,14 +90,17 @@ def plot_diplotype_clustering(
8990

9091
# Load sample metadata.
9192
df_samples = self.sample_metadata(
92-
sample_sets=sample_sets, sample_query=sample_query
93+
sample_sets=sample_sets,
94+
sample_query=sample_query,
95+
sample_query_options=sample_query_options,
9396
)
9497

9598
dist, gt_samples, n_snps_used = self.diplotype_pairwise_distances(
9699
region=region,
97100
site_mask=site_mask,
98101
sample_sets=sample_sets,
99102
sample_query=sample_query,
103+
sample_query_options=sample_query_options,
100104
cohort_size=cohort_size,
101105
distance_metric=distance_metric,
102106
random_seed=random_seed,
@@ -196,6 +200,7 @@ def diplotype_pairwise_distances(
196200
site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
197201
sample_sets: Optional[base_params.sample_sets] = None,
198202
sample_query: Optional[base_params.sample_query] = None,
203+
sample_query_options: Optional[base_params.sample_query_options] = None,
199204
site_class: Optional[base_params.site_class] = None,
200205
cohort_size: Optional[base_params.cohort_size] = None,
201206
distance_metric: dipclust_params.distance_metric = dipclust_params.distance_metric_default,
@@ -215,6 +220,7 @@ def diplotype_pairwise_distances(
215220
site_mask=site_mask,
216221
sample_sets=sample_sets_prepped,
217222
sample_query=sample_query,
223+
sample_query_options=sample_query_options,
218224
site_class=site_class,
219225
cohort_size=cohort_size,
220226
distance_metric=distance_metric,
@@ -245,6 +251,7 @@ def _diplotype_pairwise_distances(
245251
site_mask,
246252
sample_sets,
247253
sample_query,
254+
sample_query_options,
248255
site_class,
249256
cohort_size,
250257
distance_metric,
@@ -261,6 +268,7 @@ def _diplotype_pairwise_distances(
261268
ds_snps = self.snp_calls(
262269
region=region,
263270
sample_query=sample_query,
271+
sample_query_options=sample_query_options,
264272
sample_sets=sample_sets,
265273
site_mask=site_mask,
266274
site_class=site_class,
@@ -310,6 +318,7 @@ def _dipclust_het_bar_trace(
310318
dendro_sample_id_order: np.ndarray,
311319
sample_sets: Optional[base_params.sample_sets],
312320
sample_query: Optional[base_params.sample_query],
321+
sample_query_options: Optional[base_params.sample_query_options],
313322
site_mask: base_params.site_mask,
314323
cohort_size: Optional[base_params.cohort_size],
315324
random_seed: base_params.random_seed,
@@ -320,6 +329,7 @@ def _dipclust_het_bar_trace(
320329
ds_snps = self.snp_calls(
321330
region=region,
322331
sample_query=sample_query,
332+
sample_query_options=sample_query_options,
323333
sample_sets=sample_sets,
324334
cohort_size=cohort_size,
325335
site_mask=site_mask,
@@ -375,6 +385,7 @@ def _dipclust_cnv_bar_trace(
375385
dendro_sample_id_order: np.ndarray,
376386
sample_sets: Optional[base_params.sample_sets],
377387
sample_query: Optional[base_params.sample_query],
388+
sample_query_options: Optional[base_params.sample_query_options],
378389
max_coverage_variance: Optional[cnv_params.max_coverage_variance],
379390
colorscale: Optional[plotly_params.color_continuous_scale],
380391
chunks: base_params.chunks = base_params.native_chunks,
@@ -389,6 +400,7 @@ def _dipclust_cnv_bar_trace(
389400
region=cnv_region,
390401
sample_sets=sample_sets,
391402
sample_query=sample_query,
403+
sample_query_options=sample_query_options,
392404
max_coverage_variance=max_coverage_variance,
393405
chunks=chunks,
394406
inline_array=inline_array,
@@ -438,6 +450,7 @@ def _dipclust_snp_trace(
438450
snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY,
439451
sample_sets: Optional[base_params.sample_sets],
440452
sample_query: Optional[base_params.sample_query],
453+
sample_query_options: Optional[base_params.sample_query_options],
441454
site_mask: Optional[base_params.site_mask],
442455
dendro_sample_id_order: np.ndarray,
443456
snp_filter_min_maf: float,
@@ -450,6 +463,7 @@ def _dipclust_snp_trace(
450463
transcript=transcript,
451464
snp_query=snp_query,
452465
sample_query=sample_query,
466+
sample_query_options=sample_query_options,
453467
sample_sets=sample_sets,
454468
site_mask=site_mask,
455469
chunks=chunks,
@@ -555,6 +569,7 @@ def plot_diplotype_clustering_advanced(
555569
site_mask: Optional[base_params.site_mask] = None,
556570
sample_sets: Optional[base_params.sample_sets] = None,
557571
sample_query: Optional[base_params.sample_query] = None,
572+
sample_query_options: Optional[base_params.sample_query_options] = None,
558573
random_seed: base_params.random_seed = 42,
559574
cohort_size: Optional[base_params.cohort_size] = None,
560575
color: plotly_params.color = None,
@@ -594,6 +609,7 @@ def plot_diplotype_clustering_advanced(
594609
region=region,
595610
sample_sets=sample_sets,
596611
sample_query=sample_query,
612+
sample_query_options=sample_query_options,
597613
site_mask=site_mask,
598614
count_sort=count_sort,
599615
distance_metric=distance_metric,
@@ -635,6 +651,7 @@ def plot_diplotype_clustering_advanced(
635651
dendro_sample_id_order=dendro_sample_id_order,
636652
sample_sets=sample_sets,
637653
sample_query=sample_query,
654+
sample_query_options=sample_query_options,
638655
cohort_size=cohort_size,
639656
site_mask=site_mask,
640657
color_continuous_scale=heterozygosity_colorscale,
@@ -651,6 +668,7 @@ def plot_diplotype_clustering_advanced(
651668
dendro_sample_id_order=dendro_sample_id_order,
652669
sample_sets=sample_sets,
653670
sample_query=sample_query,
671+
sample_query_options=sample_query_options,
654672
max_coverage_variance=cnv_max_coverage_variance,
655673
colorscale=cnv_colorscale,
656674
chunks=chunks,
@@ -667,6 +685,7 @@ def plot_diplotype_clustering_advanced(
667685
transcript=snp_transcript,
668686
sample_sets=sample_sets,
669687
sample_query=sample_query,
688+
sample_query_options=sample_query_options,
670689
snp_query=snp_query,
671690
site_mask=site_mask,
672691
dendro_sample_id_order=dendro_sample_id_order,
@@ -693,7 +712,7 @@ def plot_diplotype_clustering_advanced(
693712
height=height,
694713
row_heights=subplot_heights,
695714
sample_sets=sample_sets,
696-
sample_query=sample_query,
715+
sample_query=sample_query, # Only uses query for title.
697716
region=region,
698717
n_snps=n_snps_cluster,
699718
)

malariagen_data/anoph/fst.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ def pairwise_average_fst(
415415
cohorts: base_params.cohorts,
416416
sample_sets: Optional[base_params.sample_sets] = None,
417417
sample_query: Optional[base_params.sample_query] = None,
418+
sample_query_options: Optional[base_params.sample_query_options] = None,
418419
cohort_size: Optional[base_params.cohort_size] = fst_params.cohort_size_default,
419420
min_cohort_size: Optional[
420421
base_params.min_cohort_size
@@ -432,6 +433,7 @@ def pairwise_average_fst(
432433
cohorts,
433434
sample_sets=sample_sets,
434435
sample_query=sample_query,
436+
sample_query_options=sample_query_options,
435437
cohort_size=cohort_size,
436438
min_cohort_size=min_cohort_size,
437439
)

0 commit comments

Comments
 (0)