Skip to content

Commit d144d53

Browse files
authored
Merge branch 'master' into GH1223-seed-random
2 parents aa8e011 + 44d93ba commit d144d53

File tree

2 files changed

+129
-71
lines changed

2 files changed

+129
-71
lines changed

malariagen_data/anoph/base_params.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
single_contig_param_type,
1111
single_region_param_type,
1212
chunks_param_type,
13+
Region,
1314
)
1415

1516
contig: TypeAlias = Annotated[
@@ -47,6 +48,8 @@
4748
""",
4849
]
4950

51+
regions_tuple: TypeAlias = Tuple[Region, ...]
52+
5053
release: TypeAlias = Annotated[
5154
Union[str, Sequence[str]],
5255
"Release version identifier.",
@@ -65,6 +68,8 @@
6568
""",
6669
]
6770

71+
sample_sets_tuple: TypeAlias = Tuple[sample_set, ...]
72+
6873
sample_query: TypeAlias = Annotated[
6974
str,
7075
"""
@@ -94,6 +99,8 @@
9499
""",
95100
]
96101

102+
sample_indices_tuple: TypeAlias = Tuple[int, ...]
103+
97104
sample: TypeAlias = Annotated[
98105
Union[str, int],
99106
"Sample identifier or index within sample set.",

malariagen_data/anoph/snp_data.py

Lines changed: 122 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,15 @@ def __init__(
6767
self._default_site_mask = default_site_mask
6868

6969
# Set up caches.
70-
# TODO review type annotations here, maybe can tighten
71-
self._cache_snp_sites = None
72-
self._cache_snp_genotypes: Dict = dict()
73-
self._cache_site_filters: Dict = dict()
74-
self._cache_site_annotations = None
75-
self._cache_locate_site_class: Dict = dict()
70+
self._cache_snp_sites: Optional[zarr.hierarchy.Group] = None
71+
self._cache_snp_genotypes: Dict[
72+
base_params.sample_set, zarr.hierarchy.Group
73+
] = dict()
74+
self._cache_site_filters: Dict[
75+
base_params.site_mask, zarr.hierarchy.Group
76+
] = dict()
77+
self._cache_site_annotations: Optional[zarr.hierarchy.Group] = None
78+
self._cache_locate_site_class: Dict[Tuple[Any, ...], np.ndarray] = dict()
7679

7780
# Create the SNP-calls cache as a per-instance lru_cache wrapping the
7881
# bound method. Storing it on the instance (rather than using a
@@ -214,7 +217,7 @@ def _site_filters_for_contig(
214217
field: base_params.field,
215218
inline_array: base_params.inline_array,
216219
chunks: base_params.chunks,
217-
):
220+
) -> da.Array:
218221
if contig in self.virtual_contigs:
219222
contigs = self.virtual_contigs[contig]
220223
arrs = [
@@ -245,7 +248,7 @@ def _site_filters_for_region(
245248
field: base_params.field,
246249
inline_array: base_params.inline_array,
247250
chunks: base_params.chunks,
248-
):
251+
) -> da.Array:
249252
d = self._site_filters_for_contig(
250253
contig=region.contig,
251254
mask=mask,
@@ -579,7 +582,7 @@ def _snp_variants_for_contig(
579582
contig: base_params.contig,
580583
inline_array: base_params.inline_array,
581584
chunks: base_params.chunks,
582-
):
585+
) -> xr.Dataset:
583586
if contig in self.virtual_contigs:
584587
contigs = self.virtual_contigs[contig]
585588
datasets = []
@@ -652,7 +655,7 @@ def snp_variants(
652655
site_mask: Optional[base_params.site_mask] = None,
653656
inline_array: base_params.inline_array = base_params.inline_array_default,
654657
chunks: base_params.chunks = base_params.native_chunks,
655-
):
658+
) -> xr.Dataset:
656659
# Normalise parameters.
657660
regions: List[Region] = _parse_multi_region(self, region)
658661
del region
@@ -789,7 +792,7 @@ def _locate_site_class(
789792
site_class: base_params.site_class,
790793
inline_array: base_params.inline_array = base_params.inline_array_default,
791794
chunks: base_params.chunks = base_params.native_chunks,
792-
):
795+
) -> np.ndarray:
793796
# Cache these data in memory to avoid repeated computation.
794797
cache_key = (region, site_mask, site_class)
795798

@@ -1082,11 +1085,11 @@ def snp_calls(
10821085
del site_mask
10831086

10841087
# Convert lists to tuples to avoid CacheMiss "TypeError: unhashable type: 'list'".
1085-
prepared_regions_tuple: Tuple[Region, ...] = tuple(prepared_regions)
1086-
prepared_sample_sets_tuple: Optional[Tuple[str, ...]] = (
1088+
prepared_regions_tuple: base_params.regions_tuple = tuple(prepared_regions)
1089+
prepared_sample_sets_tuple: Optional[base_params.sample_sets_tuple] = (
10871090
tuple(prepared_sample_sets) if prepared_sample_sets is not None else None
10881091
)
1089-
prepared_sample_indices_tuple: Optional[Tuple[int, ...]] = (
1092+
prepared_sample_indices_tuple: Optional[base_params.sample_indices_tuple] = (
10901093
tuple(prepared_sample_indices)
10911094
if prepared_sample_indices is not None
10921095
else None
@@ -1110,18 +1113,19 @@ def snp_calls(
11101113
def _raw_snp_calls(
11111114
self,
11121115
*,
1113-
regions: Tuple[Region, ...],
1114-
sample_sets,
1115-
site_mask,
1116-
site_class,
1117-
inline_array,
1118-
chunks,
1119-
):
1116+
regions: base_params.regions_tuple,
1117+
sample_sets: Optional[base_params.sample_sets_tuple],
1118+
site_mask: Optional[base_params.site_mask],
1119+
site_class: Optional[base_params.site_class],
1120+
inline_array: base_params.inline_array,
1121+
chunks: base_params.chunks,
1122+
) -> xr.Dataset:
11201123
# Access SNP calls and concatenate multiple sample sets and/or regions.
11211124
with self._spinner("Access SNP calls"):
11221125
lx = []
11231126
for r in regions:
11241127
ly = []
1128+
assert sample_sets is not None
11251129
for s in sample_sets:
11261130
y = self._snp_calls_for_contig(
11271131
contig=r.contig,
@@ -1179,18 +1183,18 @@ def _raw_snp_calls(
11791183
def _snp_calls(
11801184
self,
11811185
*,
1182-
regions: Tuple[Region, ...],
1183-
sample_sets,
1184-
sample_indices,
1185-
site_mask,
1186-
site_class,
1187-
cohort_size,
1188-
min_cohort_size,
1189-
max_cohort_size,
1190-
random_seed,
1191-
inline_array,
1192-
chunks,
1193-
):
1186+
regions: base_params.regions_tuple,
1187+
sample_sets: Optional[base_params.sample_sets_tuple],
1188+
sample_indices: Optional[base_params.sample_indices_tuple],
1189+
site_mask: Optional[base_params.site_mask],
1190+
site_class: Optional[base_params.site_class],
1191+
cohort_size: Optional[base_params.cohort_size],
1192+
min_cohort_size: Optional[base_params.min_cohort_size],
1193+
max_cohort_size: Optional[base_params.max_cohort_size],
1194+
random_seed: base_params.random_seed,
1195+
inline_array: base_params.inline_array,
1196+
chunks: base_params.chunks,
1197+
) -> xr.Dataset:
11941198
## Get SNP calls and concatenate multiple sample sets and/or regions.
11951199

11961200
# Note: sample_sets should be "prepared" before being passed to this private function.
@@ -1305,23 +1309,25 @@ def _results_cache_add_analysis_params(self, params: dict):
13051309
def _snp_allele_counts(
13061310
self,
13071311
*,
1308-
region,
1309-
sample_sets,
1310-
sample_indices,
1311-
site_mask,
1312-
site_class,
1313-
cohort_size,
1314-
min_cohort_size,
1315-
max_cohort_size,
1316-
random_seed,
1317-
inline_array,
1318-
chunks,
1319-
):
1312+
region: Union[dict, List[dict]],
1313+
sample_sets: Optional[base_params.sample_sets_tuple],
1314+
sample_indices: Optional[base_params.sample_indices_tuple],
1315+
site_mask: Optional[base_params.site_mask],
1316+
site_class: Optional[base_params.site_class],
1317+
cohort_size: Optional[base_params.cohort_size],
1318+
min_cohort_size: Optional[base_params.min_cohort_size],
1319+
max_cohort_size: Optional[base_params.max_cohort_size],
1320+
random_seed: base_params.random_seed,
1321+
inline_array: base_params.inline_array,
1322+
chunks: base_params.chunks,
1323+
) -> Dict[str, np.ndarray]:
13201324
# Access SNP calls.
1325+
# N.B., snp_calls is a public API with @_check_types, which expects
1326+
# List[int] for sample_indices, not a tuple. Convert back here.
13211327
ds_snps = self.snp_calls(
13221328
region=region,
1323-
sample_sets=sample_sets,
1324-
sample_indices=sample_indices,
1329+
sample_sets=list(sample_sets) if sample_sets is not None else None,
1330+
sample_indices=list(sample_indices) if sample_indices is not None else None,
13251331
site_mask=site_mask,
13261332
site_class=site_class,
13271333
cohort_size=cohort_size,
@@ -1402,6 +1408,15 @@ def snp_allele_counts(
14021408
sample_query_options=sample_query_options,
14031409
sample_indices=sample_indices,
14041410
)
1411+
# Convert lists to tuples to avoid CacheMiss "TypeError: unhashable type: 'list'".
1412+
sample_sets_prepped_tuple: Optional[base_params.sample_sets_tuple] = (
1413+
tuple(sample_sets_prepped) if sample_sets_prepped is not None else None
1414+
)
1415+
sample_indices_prepped_tuple: Optional[base_params.sample_indices_tuple] = (
1416+
tuple(sample_indices_prepped)
1417+
if sample_indices_prepped is not None
1418+
else None
1419+
)
14051420
del sample_sets
14061421
del sample_query
14071422
del sample_query_options
@@ -1412,8 +1427,8 @@ def snp_allele_counts(
14121427
del site_mask
14131428
params = dict(
14141429
region=region_prepped,
1415-
sample_sets=sample_sets_prepped,
1416-
sample_indices=sample_indices_prepped,
1430+
sample_sets=sample_sets_prepped_tuple,
1431+
sample_indices=sample_indices_prepped_tuple,
14171432
site_mask=site_mask_prepped,
14181433
site_class=site_class,
14191434
cohort_size=cohort_size,
@@ -1427,7 +1442,17 @@ def snp_allele_counts(
14271442

14281443
except CacheMiss:
14291444
results = self._snp_allele_counts(
1430-
**params, inline_array=inline_array, chunks=chunks
1445+
inline_array=inline_array,
1446+
chunks=chunks,
1447+
region=region_prepped,
1448+
sample_sets=sample_sets_prepped_tuple,
1449+
sample_indices=sample_indices_prepped_tuple,
1450+
site_mask=site_mask_prepped,
1451+
site_class=site_class,
1452+
cohort_size=cohort_size,
1453+
min_cohort_size=min_cohort_size,
1454+
max_cohort_size=max_cohort_size,
1455+
random_seed=random_seed,
14311456
)
14321457
self.results_cache_set(name=name, params=params, results=results)
14331458

@@ -1964,6 +1989,16 @@ def biallelic_diplotypes(
19641989
prepared_region = self._prep_region_cache_param(region=region)
19651990
prepared_site_mask = self._prep_optional_site_mask_param(site_mask=site_mask)
19661991

1992+
# Convert lists to tuples to avoid CacheMiss "TypeError: unhashable type: 'list'".
1993+
prepared_sample_sets_tuple: Optional[base_params.sample_sets_tuple] = (
1994+
tuple(prepared_sample_sets) if prepared_sample_sets is not None else None
1995+
)
1996+
prepared_sample_indices_tuple: Optional[base_params.sample_indices_tuple] = (
1997+
tuple(prepared_sample_indices)
1998+
if prepared_sample_indices is not None
1999+
else None
2000+
)
2001+
19672002
# Delete original parameters to prevent accidental use.
19682003
del sample_sets
19692004
del sample_query
@@ -1976,8 +2011,8 @@ def biallelic_diplotypes(
19762011
region=prepared_region,
19772012
n_snps=n_snps,
19782013
thin_offset=thin_offset,
1979-
sample_sets=prepared_sample_sets,
1980-
sample_indices=prepared_sample_indices,
2014+
sample_sets=prepared_sample_sets_tuple,
2015+
sample_indices=prepared_sample_indices_tuple,
19812016
site_mask=prepared_site_mask,
19822017
site_class=site_class,
19832018
cohort_size=cohort_size,
@@ -1994,7 +2029,21 @@ def biallelic_diplotypes(
19942029

19952030
except CacheMiss:
19962031
results = self._biallelic_diplotypes(
1997-
inline_array=inline_array, chunks=chunks, **params
2032+
inline_array=inline_array,
2033+
chunks=chunks,
2034+
region=prepared_region,
2035+
sample_sets=prepared_sample_sets_tuple,
2036+
sample_indices=prepared_sample_indices_tuple,
2037+
site_mask=prepared_site_mask,
2038+
site_class=site_class,
2039+
cohort_size=cohort_size,
2040+
min_cohort_size=min_cohort_size,
2041+
max_cohort_size=max_cohort_size,
2042+
random_seed=random_seed,
2043+
n_snps=n_snps,
2044+
thin_offset=thin_offset,
2045+
min_minor_ac=min_minor_ac,
2046+
max_missing_an=max_missing_an,
19982047
)
19992048
self.results_cache_set(name=name, params=params, results=results)
20002049

@@ -2007,29 +2056,31 @@ def biallelic_diplotypes(
20072056
def _biallelic_diplotypes(
20082057
self,
20092058
*,
2010-
region,
2011-
sample_sets,
2012-
sample_indices,
2013-
site_mask,
2014-
site_class,
2015-
cohort_size,
2016-
min_cohort_size,
2017-
max_cohort_size,
2018-
random_seed,
2019-
max_missing_an,
2020-
min_minor_ac,
2021-
n_snps,
2022-
thin_offset,
2023-
inline_array,
2024-
chunks,
2025-
):
2059+
region: Union[dict, List[dict]],
2060+
sample_sets: Optional[base_params.sample_sets_tuple],
2061+
sample_indices: Optional[base_params.sample_indices_tuple],
2062+
site_mask: Optional[base_params.site_mask],
2063+
site_class: Optional[base_params.site_class],
2064+
cohort_size: Optional[base_params.cohort_size],
2065+
min_cohort_size: Optional[base_params.min_cohort_size],
2066+
max_cohort_size: Optional[base_params.max_cohort_size],
2067+
random_seed: base_params.random_seed,
2068+
max_missing_an: Optional[base_params.max_missing_an],
2069+
min_minor_ac: Optional[base_params.min_minor_ac],
2070+
n_snps: Optional[base_params.n_snps],
2071+
thin_offset: base_params.thin_offset,
2072+
inline_array: base_params.inline_array,
2073+
chunks: base_params.chunks,
2074+
) -> Dict[str, np.ndarray]:
20262075
# Note: this function uses sample_indices and should not expect a sample_query.
20272076

20282077
# Access biallelic SNPs.
2078+
# N.B., biallelic_snp_calls is a public API with @_check_types, which
2079+
# expects List[int] for sample_indices, not a tuple. Convert back here.
20292080
ds = self.biallelic_snp_calls(
20302081
region=region,
2031-
sample_sets=sample_sets,
2032-
sample_indices=sample_indices,
2082+
sample_sets=list(sample_sets) if sample_sets is not None else None,
2083+
sample_indices=list(sample_indices) if sample_indices is not None else None,
20332084
site_mask=site_mask,
20342085
site_class=site_class,
20352086
cohort_size=cohort_size,

0 commit comments

Comments
 (0)