Skip to content

Commit ae1042f

Browse files
committed
Add type annotations to snp_data.py
1 parent b5f8b6e commit ae1042f

1 file changed

Lines changed: 61 additions & 58 deletions

File tree

malariagen_data/anoph/snp_data.py

Lines changed: 61 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,15 @@ def __init__(
6767
self._default_site_mask = default_site_mask
6868

6969
# Set up caches.
70-
# TODO review type annotations here, maybe can tighten
71-
self._cache_snp_sites = None
72-
self._cache_snp_genotypes: Dict = dict()
73-
self._cache_site_filters: Dict = dict()
74-
self._cache_site_annotations = None
75-
self._cache_locate_site_class: Dict = dict()
70+
self._cache_snp_sites: Optional[zarr.hierarchy.Group] = None
71+
self._cache_snp_genotypes: Dict[base_params.sample_set, zarr.hierarchy.Group] = (
72+
dict()
73+
)
74+
self._cache_site_filters: Dict[base_params.site_mask, zarr.hierarchy.Group] = (
75+
dict()
76+
)
77+
self._cache_site_annotations: Optional[zarr.hierarchy.Group] = None
78+
self._cache_locate_site_class: Dict[Tuple[Any, ...], np.ndarray] = dict()
7679

7780
# Create the SNP-calls cache as a per-instance lru_cache wrapping the
7881
# bound method. Storing it on the instance (rather than using a
@@ -108,7 +111,7 @@ def _prep_site_mask_param(
108111
self,
109112
*,
110113
site_mask: base_params.site_mask,
111-
) -> base_params.site_mask:
114+
) -> str:
112115
if site_mask == base_params.DEFAULT:
113116
# Use whatever is the default site mask for this data resource.
114117
assert self._default_site_mask is not None
@@ -122,7 +125,7 @@ def _prep_optional_site_mask_param(
122125
self,
123126
*,
124127
site_mask: Optional[base_params.site_mask],
125-
) -> Optional[base_params.site_mask]:
128+
) -> Optional[str]:
126129
if site_mask is None:
127130
# This is allowed, it means don't apply any site mask to the data.
128131
return None
@@ -214,7 +217,7 @@ def _site_filters_for_contig(
214217
field: base_params.field,
215218
inline_array: base_params.inline_array,
216219
chunks: base_params.chunks,
217-
):
220+
) -> da.Array:
218221
if contig in self.virtual_contigs:
219222
contigs = self.virtual_contigs[contig]
220223
arrs = [
@@ -245,7 +248,7 @@ def _site_filters_for_region(
245248
field: base_params.field,
246249
inline_array: base_params.inline_array,
247250
chunks: base_params.chunks,
248-
):
251+
) -> da.Array:
249252
d = self._site_filters_for_contig(
250253
contig=region.contig,
251254
mask=mask,
@@ -579,7 +582,7 @@ def _snp_variants_for_contig(
579582
contig: base_params.contig,
580583
inline_array: base_params.inline_array,
581584
chunks: base_params.chunks,
582-
):
585+
) -> xr.Dataset:
583586
if contig in self.virtual_contigs:
584587
contigs = self.virtual_contigs[contig]
585588
datasets = []
@@ -652,7 +655,7 @@ def snp_variants(
652655
site_mask: Optional[base_params.site_mask] = None,
653656
inline_array: base_params.inline_array = base_params.inline_array_default,
654657
chunks: base_params.chunks = base_params.native_chunks,
655-
):
658+
) -> xr.Dataset:
656659
# Normalise parameters.
657660
regions: List[Region] = _parse_multi_region(self, region)
658661
del region
@@ -789,7 +792,7 @@ def _locate_site_class(
789792
site_class: base_params.site_class,
790793
inline_array: base_params.inline_array = base_params.inline_array_default,
791794
chunks: base_params.chunks = base_params.native_chunks,
792-
):
795+
) -> np.ndarray:
793796
# Cache these data in memory to avoid repeated computation.
794797
cache_key = (region, site_mask, site_class)
795798

@@ -1111,12 +1114,12 @@ def _raw_snp_calls(
11111114
self,
11121115
*,
11131116
regions: Tuple[Region, ...],
1114-
sample_sets,
1115-
site_mask,
1116-
site_class,
1117-
inline_array,
1118-
chunks,
1119-
):
1117+
sample_sets: Optional[Tuple[str, ...]],
1118+
site_mask: Optional[str],
1119+
site_class: Optional[base_params.site_class],
1120+
inline_array: base_params.inline_array,
1121+
chunks: base_params.chunks,
1122+
) -> xr.Dataset:
11201123
# Access SNP calls and concatenate multiple sample sets and/or regions.
11211124
with self._spinner("Access SNP calls"):
11221125
lx = []
@@ -1180,17 +1183,17 @@ def _snp_calls(
11801183
self,
11811184
*,
11821185
regions: Tuple[Region, ...],
1183-
sample_sets,
1184-
sample_indices,
1185-
site_mask,
1186-
site_class,
1187-
cohort_size,
1188-
min_cohort_size,
1189-
max_cohort_size,
1190-
random_seed,
1191-
inline_array,
1192-
chunks,
1193-
):
1186+
sample_sets: Optional[Tuple[str, ...]],
1187+
sample_indices: Optional[Tuple[int, ...]],
1188+
site_mask: Optional[str],
1189+
site_class: Optional[base_params.site_class],
1190+
cohort_size: Optional[base_params.cohort_size],
1191+
min_cohort_size: Optional[base_params.min_cohort_size],
1192+
max_cohort_size: Optional[base_params.max_cohort_size],
1193+
random_seed: base_params.random_seed,
1194+
inline_array: base_params.inline_array,
1195+
chunks: base_params.chunks,
1196+
) -> xr.Dataset:
11941197
## Get SNP calls and concatenate multiple sample sets and/or regions.
11951198

11961199
# Note: sample_sets should be "prepared" before being passed to this private function.
@@ -1305,18 +1308,18 @@ def _results_cache_add_analysis_params(self, params: dict):
13051308
def _snp_allele_counts(
13061309
self,
13071310
*,
1308-
region,
1309-
sample_sets,
1310-
sample_indices,
1311-
site_mask,
1312-
site_class,
1313-
cohort_size,
1314-
min_cohort_size,
1315-
max_cohort_size,
1316-
random_seed,
1317-
inline_array,
1318-
chunks,
1319-
):
1311+
region: Union[dict, List[dict]],
1312+
sample_sets: Optional[Tuple[str, ...]],
1313+
sample_indices: Optional[Tuple[int, ...]],
1314+
site_mask: Optional[str],
1315+
site_class: Optional[base_params.site_class],
1316+
cohort_size: Optional[base_params.cohort_size],
1317+
min_cohort_size: Optional[base_params.min_cohort_size],
1318+
max_cohort_size: Optional[base_params.max_cohort_size],
1319+
random_seed: base_params.random_seed,
1320+
inline_array: base_params.inline_array,
1321+
chunks: base_params.chunks,
1322+
) -> Dict[str, np.ndarray]:
13201323
# Access SNP calls.
13211324
ds_snps = self.snp_calls(
13221325
region=region,
@@ -2007,22 +2010,22 @@ def biallelic_diplotypes(
20072010
def _biallelic_diplotypes(
20082011
self,
20092012
*,
2010-
region,
2011-
sample_sets,
2012-
sample_indices,
2013-
site_mask,
2014-
site_class,
2015-
cohort_size,
2016-
min_cohort_size,
2017-
max_cohort_size,
2018-
random_seed,
2019-
max_missing_an,
2020-
min_minor_ac,
2021-
n_snps,
2022-
thin_offset,
2023-
inline_array,
2024-
chunks,
2025-
):
2013+
region: Union[dict, List[dict]],
2014+
sample_sets: Optional[Tuple[str, ...]],
2015+
sample_indices: Optional[Tuple[int, ...]],
2016+
site_mask: Optional[str],
2017+
site_class: Optional[base_params.site_class],
2018+
cohort_size: Optional[base_params.cohort_size],
2019+
min_cohort_size: Optional[base_params.min_cohort_size],
2020+
max_cohort_size: Optional[base_params.max_cohort_size],
2021+
random_seed: base_params.random_seed,
2022+
max_missing_an: Optional[base_params.max_missing_an],
2023+
min_minor_ac: Optional[base_params.min_minor_ac],
2024+
n_snps: Optional[base_params.n_snps],
2025+
thin_offset: base_params.thin_offset,
2026+
inline_array: base_params.inline_array,
2027+
chunks: base_params.chunks,
2028+
) -> Dict[str, np.ndarray]:
20262029
# Note: this function uses sample_indices and should not expect a sample_query.
20272030

20282031
# Access biallelic SNPs.

0 commit comments

Comments
 (0)