Skip to content

Commit 858c06f

Browse files
authored
Merge pull request #1240 from Yashsingh045/GH1237-seed-random
Standardize test suite randomization using np.random
2 parents fd60009 + 49074e4 commit 858c06f

21 files changed

+422
-423
lines changed

tests/anoph/test_aim_data.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import itertools
2-
import random
32

43
import plotly.graph_objects as go
4+
import numpy as np
55
import pytest
66
import xarray as xr
77
from numpy.testing import assert_array_equal
@@ -88,9 +88,9 @@ def test_aim_calls(aims, ag3_sim_api):
8888
all_releases = api.releases
8989
parametrize_sample_sets = [
9090
None,
91-
random.choice(all_sample_sets),
92-
random.sample(all_sample_sets, 2),
93-
random.choice(all_releases),
91+
str(np.random.choice(all_sample_sets)),
92+
np.random.choice(all_sample_sets, size=2, replace=False).tolist(),
93+
np.random.choice(all_releases),
9494
]
9595

9696
# Parametrize sample_query.
@@ -179,9 +179,9 @@ def test_plot_aim_heatmap(aims, ag3_sim_api):
179179
all_releases = api.releases
180180
parametrize_sample_sets = [
181181
None,
182-
random.choice(all_sample_sets),
183-
random.sample(all_sample_sets, 2),
184-
random.choice(all_releases),
182+
str(np.random.choice(all_sample_sets)),
183+
np.random.choice(all_sample_sets, size=2, replace=False).tolist(),
184+
np.random.choice(all_releases),
185185
]
186186

187187
# Parametrize sample_query.

tests/anoph/test_cnv_data.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import random
2-
31
import bokeh.models
42
import dask.array as da
53
import numpy as np
@@ -136,14 +134,15 @@ def test_open_cnv_coverage_calls(fixture, api: AnophelesCnvData):
136134
# Check with a sample set that should not exist
137135
with pytest.raises(ValueError):
138136
root = api.open_cnv_coverage_calls(
139-
sample_set="foobar", analysis=random.choice(api.coverage_calls_analysis_ids)
137+
sample_set="foobar",
138+
analysis=str(np.random.choice(api.coverage_calls_analysis_ids)),
140139
)
141140

142141
# Check with an analysis that should not exist
143142
all_sample_sets = api.sample_sets()["sample_set"].to_list()
144143
with pytest.raises(ValueError):
145144
root = api.open_cnv_coverage_calls(
146-
sample_set=random.choice(all_sample_sets), analysis="foobar"
145+
sample_set=str(np.random.choice(all_sample_sets)), analysis="foobar"
147146
)
148147

149148
# Check with a sample set and analysis that should not exist
@@ -343,15 +342,15 @@ def test_cnv_hmm(fixture, api: AnophelesCnvData):
343342
all_sample_sets = api.sample_sets()["sample_set"].to_list()
344343
parametrize_sample_sets = [
345344
None,
346-
random.choice(all_sample_sets),
347-
random.sample(all_sample_sets, 2),
348-
random.choice(all_releases),
345+
str(np.random.choice(all_sample_sets)),
346+
np.random.choice(all_sample_sets, size=2, replace=False).tolist(),
347+
np.random.choice(all_releases),
349348
]
350349

351350
# Parametrize region.
352351
parametrize_region = [
353352
fixture.random_contig(),
354-
random.sample(api.contigs, 2),
353+
np.random.choice(api.contigs, size=2, replace=False).tolist(),
355354
fixture.random_region_str(),
356355
]
357356

@@ -421,7 +420,7 @@ def test_cnv_hmm(fixture, api: AnophelesCnvData):
421420
def test_cnv_hmm__max_coverage_variance(fixture, api: AnophelesCnvData):
422421
# Set up test.
423422
all_sample_sets = api.sample_sets()["sample_set"].to_list()
424-
sample_set = random.choice(all_sample_sets)
423+
sample_set = str(np.random.choice(all_sample_sets))
425424
region = fixture.random_contig()
426425

427426
# Parametrize max_coverage_variance.
@@ -465,15 +464,17 @@ def test_cnv_hmm__max_coverage_variance(fixture, api: AnophelesCnvData):
465464
def test_cnv_coverage_calls(fixture, api: AnophelesCnvData):
466465
# Parametrize sample_sets.
467466
all_sample_sets = api.sample_sets()["sample_set"].to_list()
468-
parametrize_sample_sets = random.sample(all_sample_sets, 3)
467+
parametrize_sample_sets = np.random.choice(
468+
all_sample_sets, size=3, replace=False
469+
).tolist()
469470

470471
# Parametrize analysis.
471472
parametrize_analysis = api.coverage_calls_analysis_ids
472473

473474
# Parametrize region.
474475
parametrize_region = [
475476
fixture.random_contig(),
476-
random.sample(api.contigs, 2),
477+
np.random.choice(api.contigs, size=2, replace=False).tolist(),
477478
fixture.random_region_str(),
478479
]
479480

@@ -551,15 +552,15 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData):
551552
all_sample_sets = api.sample_sets()["sample_set"].to_list()
552553
parametrize_sample_sets = [
553554
None,
554-
random.choice(all_sample_sets),
555-
random.sample(all_sample_sets, 2),
556-
random.choice(all_releases),
555+
str(np.random.choice(all_sample_sets)),
556+
np.random.choice(all_sample_sets, size=2, replace=False).tolist(),
557+
np.random.choice(all_releases),
557558
]
558559

559560
# Parametrize contig.
560561
parametrize_contig = [
561-
random.choice(api.contigs),
562-
random.sample(api.contigs, 2),
562+
str(np.random.choice(api.contigs)),
563+
np.random.choice(api.contigs, size=2, replace=False).tolist(),
563564
]
564565

565566
for sample_sets in parametrize_sample_sets:
@@ -631,13 +632,13 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData):
631632
match="No CNV discordant read calls data found|no CNVs available for contig",
632633
):
633634
api.cnv_discordant_read_calls(
634-
contigs="foobar", sample_sets=random.choice(all_sample_sets)
635+
contigs="foobar", sample_sets=str(np.random.choice(all_sample_sets))
635636
)
636637

637638
# Check with a sample set that should not exist
638639
with pytest.raises(ValueError):
639640
api.cnv_discordant_read_calls(
640-
contigs=random.choice(api.contigs), sample_sets="foobar"
641+
contigs=np.random.choice(api.contigs), sample_sets="foobar"
641642
)
642643

643644
# Check with a contig and sample set that should not exist
@@ -649,8 +650,8 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData):
649650
def test_cnv_discordant_read_calls_deprecated_contig_alias(
650651
fixture, api: AnophelesCnvData
651652
):
652-
sample_set = random.choice(api.sample_sets()["sample_set"].to_list())
653-
contig = random.choice(api.contigs)
653+
sample_set = str(np.random.choice(api.sample_sets()["sample_set"].to_list()))
654+
contig = str(np.random.choice(api.contigs))
654655
ds_contigs = api.cnv_discordant_read_calls(contigs=contig, sample_sets=sample_set)
655656
with pytest.warns(DeprecationWarning, match="deprecated"):
656657
ds_contig = api.cnv_discordant_read_calls(contig=contig, sample_sets=sample_set)
@@ -821,7 +822,7 @@ def test_cnv_discordant_read_calls__sample_query_options(
821822
def test_plot_cnv_hmm_coverage_track(fixture, api: AnophelesCnvData):
822823
# Set up test.
823824
all_sample_sets = api.sample_sets()["sample_set"].to_list()
824-
sample_set = random.choice(all_sample_sets)
825+
sample_set = str(np.random.choice(all_sample_sets))
825826
region = fixture.random_contig()
826827
df_samples = api.sample_metadata(sample_sets=sample_set)
827828
all_sample_ids = df_samples["sample_id"].values
@@ -874,7 +875,7 @@ def test_plot_cnv_hmm_coverage_track(fixture, api: AnophelesCnvData):
874875
def test_plot_cnv_hmm_coverage(fixture, api: AnophelesCnvData):
875876
# Set up test.
876877
all_sample_sets = api.sample_sets()["sample_set"].to_list()
877-
sample_set = random.choice(all_sample_sets)
878+
sample_set = str(np.random.choice(all_sample_sets))
878879
region = fixture.random_contig()
879880
df_samples = api.sample_metadata(sample_sets=sample_set)
880881
all_sample_ids = df_samples["sample_id"].values
@@ -928,9 +929,9 @@ def test_plot_cnv_hmm_heatmap_track(fixture, api: AnophelesCnvData):
928929
all_sample_sets = api.sample_sets()["sample_set"].to_list()
929930
parametrize_sample_sets = [
930931
None,
931-
random.choice(all_sample_sets),
932-
random.sample(all_sample_sets, 2),
933-
random.choice(all_releases),
932+
str(np.random.choice(all_sample_sets)),
933+
np.random.choice(all_sample_sets, size=2, replace=False).tolist(),
934+
np.random.choice(all_releases),
934935
]
935936

936937
for region in parametrize_region:

tests/anoph/test_cnv_frq.py

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import random
2-
31
import numpy as np
42
import pandas as pd
53
import xarray as xr
@@ -95,10 +93,10 @@ def test_gene_cnv_frequencies_with_str_cohorts(
9593
api: AnophelesCnvFrequencyAnalysis,
9694
cohorts,
9795
):
98-
region = random.choice(api.contigs)
96+
region = str(np.random.choice(api.contigs))
9997
all_sample_sets = api.sample_sets()["sample_set"].to_list()
100-
sample_sets = random.choice(all_sample_sets)
101-
min_cohort_size = random.randint(0, 2)
98+
sample_sets = str(np.random.choice(all_sample_sets))
99+
min_cohort_size = int(np.random.randint(0, 3))
102100

103101
# Set up call params.
104102
params = dict(
@@ -148,8 +146,8 @@ def test_gene_cnv_frequencies_with_min_cohort_size(
148146
):
149147
# Pick test parameters at random.
150148
all_sample_sets = api.sample_sets()["sample_set"].to_list()
151-
sample_sets = random.choice(all_sample_sets)
152-
region = random.choice(api.contigs)
149+
sample_sets = str(np.random.choice(all_sample_sets))
150+
region = str(np.random.choice(api.contigs))
153151
cohorts = "admin1_year"
154152

155153
# Set up call params.
@@ -199,13 +197,13 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query(
199197
# Pick test parameters at random.
200198
sample_sets = None
201199
min_cohort_size = 0
202-
region = random.choice(api.contigs)
203-
cohorts = random.choice(
204-
["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
200+
region = str(np.random.choice(api.contigs))
201+
cohorts = str(
202+
np.random.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"])
205203
)
206204
df_samples = api.sample_metadata(sample_sets=sample_sets)
207205
countries = df_samples["country"].unique()
208-
country = random.choice(countries)
206+
country = str(np.random.choice(countries))
209207
sample_query = f"country == '{country}'"
210208

211209
# Figure out expected cohort labels.
@@ -247,13 +245,13 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query_options(
247245
# Pick test parameters at random.
248246
sample_sets = None
249247
min_cohort_size = 0
250-
region = random.choice(api.contigs)
251-
cohorts = random.choice(
252-
["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
248+
region = str(np.random.choice(api.contigs))
249+
cohorts = str(
250+
np.random.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"])
253251
)
254252
df_samples = api.sample_metadata(sample_sets=sample_sets)
255253
countries = df_samples["country"].unique().tolist()
256-
countries_list = random.sample(countries, 2)
254+
countries_list = np.random.choice(countries, size=2, replace=False).tolist()
257255
sample_query_options = {
258256
"local_dict": {
259257
"countries_list": countries_list,
@@ -303,8 +301,8 @@ def test_gene_cnv_frequencies_with_dict_cohorts(
303301
):
304302
# Pick test parameters at random.
305303
sample_sets = None # all sample sets
306-
min_cohort_size = random.randint(0, 2)
307-
region = random.choice(api.contigs)
304+
min_cohort_size = int(np.random.randint(0, 3))
305+
region = str(np.random.choice(api.contigs))
308306

309307
# Create cohorts by country.
310308
df_samples = api.sample_metadata(sample_sets=sample_sets)
@@ -343,10 +341,10 @@ def test_gene_cnv_frequencies_without_drop_invariant(
343341
):
344342
# Pick test parameters at random.
345343
all_sample_sets = api.sample_sets()["sample_set"].to_list()
346-
sample_sets = random.choice(all_sample_sets)
347-
min_cohort_size = random.randint(0, 2)
348-
region = random.choice(api.contigs)
349-
cohorts = random.choice(["admin1_year", "admin2_month", "country"])
344+
sample_sets = str(np.random.choice(all_sample_sets))
345+
min_cohort_size = int(np.random.randint(0, 3))
346+
region = str(np.random.choice(api.contigs))
347+
cohorts = str(np.random.choice(["admin1_year", "admin2_month", "country"]))
350348

351349
# Figure out expected cohort labels.
352350
df_samples = api.sample_metadata(sample_sets=sample_sets)
@@ -398,9 +396,9 @@ def test_gene_cnv_frequencies_with_bad_region(
398396
):
399397
# Pick test parameters at random.
400398
all_sample_sets = api.sample_sets()["sample_set"].to_list()
401-
sample_sets = random.choice(all_sample_sets)
402-
min_cohort_size = random.randint(0, 2)
403-
cohorts = random.choice(["admin1_year", "admin2_month", "country"])
399+
sample_sets = str(np.random.choice(all_sample_sets))
400+
min_cohort_size = int(np.random.randint(0, 3))
401+
cohorts = str(np.random.choice(["admin1_year", "admin2_month", "country"]))
404402

405403
# Set up call params.
406404
params = dict(
@@ -424,9 +422,9 @@ def test_gene_cnv_frequencies_with_max_coverage_variance(
424422
max_coverage_variance,
425423
):
426424
all_sample_sets = api.sample_sets()["sample_set"].to_list()
427-
sample_sets = random.choice(all_sample_sets)
428-
cohorts = random.choice(["admin1_year", "admin2_month", "country"])
429-
region = random.choice(api.contigs)
425+
sample_sets = str(np.random.choice(all_sample_sets))
426+
cohorts = str(np.random.choice(["admin1_year", "admin2_month", "country"]))
427+
region = str(np.random.choice(api.contigs))
430428

431429
params = dict(
432430
region=region,
@@ -503,7 +501,7 @@ def test_gene_cnv_frequencies_advanced_with_sample_query(
503501
all_sample_sets = api.sample_sets()["sample_set"].to_list()
504502
df_samples = api.sample_metadata(sample_sets=all_sample_sets)
505503
countries = df_samples["country"].unique()
506-
country = random.choice(countries)
504+
country = str(np.random.choice(countries))
507505
sample_query = f"country == '{country}'"
508506

509507
check_gene_cnv_frequencies_advanced(
@@ -522,7 +520,7 @@ def test_gene_cnv_frequencies_advanced_with_sample_query_options(
522520
all_sample_sets = api.sample_sets()["sample_set"].to_list()
523521
df_samples = api.sample_metadata(sample_sets=all_sample_sets)
524522
countries = df_samples["country"].unique().tolist()
525-
countries_list = random.sample(countries, 2)
523+
countries_list = np.random.choice(countries, size=2, replace=False).tolist()
526524
sample_query_options = {
527525
"local_dict": {
528526
"countries_list": countries_list,
@@ -549,7 +547,7 @@ def test_gene_cnv_frequencies_advanced_with_min_cohort_size(
549547
all_sample_sets = api.sample_sets()["sample_set"].to_list()
550548
area_by = "admin1_iso"
551549
period_by = "year"
552-
region = random.choice(api.contigs)
550+
region = str(np.random.choice(api.contigs))
553551

554552
if min_cohort_size <= 10:
555553
# Expect this to find at least one cohort, so go ahead with full
@@ -585,7 +583,7 @@ def test_gene_cnv_frequencies_advanced_with_max_coverage_variance(
585583
all_sample_sets = api.sample_sets()["sample_set"].to_list()
586584
area_by = "admin1_iso"
587585
period_by = "year"
588-
region = random.choice(api.contigs)
586+
region = str(np.random.choice(api.contigs))
589587

590588
if max_coverage_variance >= 0.4:
591589
# Expect this to find at least one cohort, so go ahead with full
@@ -620,7 +618,7 @@ def test_gene_cnv_frequencies_advanced_with_nobs_mode(
620618
all_sample_sets = api.sample_sets()["sample_set"].to_list()
621619
area_by = "admin1_iso"
622620
period_by = "year"
623-
region = random.choice(api.contigs)
621+
region = str(np.random.choice(api.contigs))
624622

625623
check_gene_cnv_frequencies_advanced(
626624
api=api,
@@ -642,7 +640,7 @@ def test_gene_cnv_frequencies_advanced_with_variant_query(
642640
all_sample_sets = api.sample_sets()["sample_set"].to_list()
643641
area_by = "admin1_iso"
644642
period_by = "year"
645-
region = random.choice(api.contigs)
643+
region = str(np.random.choice(api.contigs))
646644
variant_query = f"cnv_type == '{variant_query_option}'"
647645

648646
check_gene_cnv_frequencies_advanced(
@@ -710,16 +708,16 @@ def check_gene_cnv_frequencies_advanced(
710708
):
711709
# Pick test parameters at random.
712710
if region is None:
713-
region = random.choice(api.contigs)
711+
region = str(np.random.choice(api.contigs))
714712
if area_by is None:
715-
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
713+
area_by = str(np.random.choice(["country", "admin1_iso", "admin2_name"]))
716714
if period_by is None:
717-
period_by = random.choice(["year", "quarter", "month", "random_year"])
715+
period_by = str(np.random.choice(["year", "quarter", "month", "random_year"]))
718716
if sample_sets is None:
719717
all_sample_sets = api.sample_sets()["sample_set"].to_list()
720-
sample_sets = random.choice(all_sample_sets)
718+
sample_sets = str(np.random.choice(all_sample_sets))
721719
if min_cohort_size is None:
722-
min_cohort_size = random.randint(0, 2)
720+
min_cohort_size = int(np.random.randint(0, 3))
723721

724722
if period_by == "random_year":
725723
# Add a random_year column to the sample metadata, if there isn't already.

0 commit comments

Comments
 (0)