@@ -67,12 +67,15 @@ def __init__(
6767 self ._default_site_mask = default_site_mask
6868
6969 # Set up caches.
70- # TODO review type annotations here, maybe can tighten
71- self ._cache_snp_sites = None
72- self ._cache_snp_genotypes : Dict = dict ()
73- self ._cache_site_filters : Dict = dict ()
74- self ._cache_site_annotations = None
75- self ._cache_locate_site_class : Dict = dict ()
70+ self ._cache_snp_sites : Optional [zarr .hierarchy .Group ] = None
71+ self ._cache_snp_genotypes : Dict [
72+ base_params .sample_set , zarr .hierarchy .Group
73+ ] = dict ()
74+ self ._cache_site_filters : Dict [
75+ base_params .site_mask , zarr .hierarchy .Group
76+ ] = dict ()
77+ self ._cache_site_annotations : Optional [zarr .hierarchy .Group ] = None
78+ self ._cache_locate_site_class : Dict [Tuple [Any , ...], np .ndarray ] = dict ()
7679
7780 # Create the SNP-calls cache as a per-instance lru_cache wrapping the
7881 # bound method. Storing it on the instance (rather than using a
@@ -214,7 +217,7 @@ def _site_filters_for_contig(
214217 field : base_params .field ,
215218 inline_array : base_params .inline_array ,
216219 chunks : base_params .chunks ,
217- ):
220+ ) -> da . Array :
218221 if contig in self .virtual_contigs :
219222 contigs = self .virtual_contigs [contig ]
220223 arrs = [
@@ -245,7 +248,7 @@ def _site_filters_for_region(
245248 field : base_params .field ,
246249 inline_array : base_params .inline_array ,
247250 chunks : base_params .chunks ,
248- ):
251+ ) -> da . Array :
249252 d = self ._site_filters_for_contig (
250253 contig = region .contig ,
251254 mask = mask ,
@@ -579,7 +582,7 @@ def _snp_variants_for_contig(
579582 contig : base_params .contig ,
580583 inline_array : base_params .inline_array ,
581584 chunks : base_params .chunks ,
582- ):
585+ ) -> xr . Dataset :
583586 if contig in self .virtual_contigs :
584587 contigs = self .virtual_contigs [contig ]
585588 datasets = []
@@ -652,7 +655,7 @@ def snp_variants(
652655 site_mask : Optional [base_params .site_mask ] = None ,
653656 inline_array : base_params .inline_array = base_params .inline_array_default ,
654657 chunks : base_params .chunks = base_params .native_chunks ,
655- ):
658+ ) -> xr . Dataset :
656659 # Normalise parameters.
657660 regions : List [Region ] = _parse_multi_region (self , region )
658661 del region
@@ -789,7 +792,7 @@ def _locate_site_class(
789792 site_class : base_params .site_class ,
790793 inline_array : base_params .inline_array = base_params .inline_array_default ,
791794 chunks : base_params .chunks = base_params .native_chunks ,
792- ):
795+ ) -> np . ndarray :
793796 # Cache these data in memory to avoid repeated computation.
794797 cache_key = (region , site_mask , site_class )
795798
@@ -1082,11 +1085,11 @@ def snp_calls(
10821085 del site_mask
10831086
10841087 # Convert lists to tuples to avoid CacheMiss "TypeError: unhashable type: 'list'".
1085- prepared_regions_tuple : Tuple [ Region , ...] = tuple (prepared_regions )
1086- prepared_sample_sets_tuple : Optional [Tuple [ str , ...] ] = (
1088+ prepared_regions_tuple : base_params . regions_tuple = tuple (prepared_regions )
1089+ prepared_sample_sets_tuple : Optional [base_params . sample_sets_tuple ] = (
10871090 tuple (prepared_sample_sets ) if prepared_sample_sets is not None else None
10881091 )
1089- prepared_sample_indices_tuple : Optional [Tuple [ int , ...] ] = (
1092+ prepared_sample_indices_tuple : Optional [base_params . sample_indices_tuple ] = (
10901093 tuple (prepared_sample_indices )
10911094 if prepared_sample_indices is not None
10921095 else None
@@ -1110,18 +1113,19 @@ def snp_calls(
11101113 def _raw_snp_calls (
11111114 self ,
11121115 * ,
1113- regions : Tuple [ Region , ...] ,
1114- sample_sets ,
1115- site_mask ,
1116- site_class ,
1117- inline_array ,
1118- chunks ,
1119- ):
1116+ regions : base_params . regions_tuple ,
1117+ sample_sets : Optional [ base_params . sample_sets_tuple ] ,
1118+ site_mask : Optional [ base_params . site_mask ] ,
1119+ site_class : Optional [ base_params . site_class ] ,
1120+ inline_array : base_params . inline_array ,
1121+ chunks : base_params . chunks ,
1122+ ) -> xr . Dataset :
11201123 # Access SNP calls and concatenate multiple sample sets and/or regions.
11211124 with self ._spinner ("Access SNP calls" ):
11221125 lx = []
11231126 for r in regions :
11241127 ly = []
1128+ assert sample_sets is not None
11251129 for s in sample_sets :
11261130 y = self ._snp_calls_for_contig (
11271131 contig = r .contig ,
@@ -1179,18 +1183,18 @@ def _raw_snp_calls(
11791183 def _snp_calls (
11801184 self ,
11811185 * ,
1182- regions : Tuple [ Region , ...] ,
1183- sample_sets ,
1184- sample_indices ,
1185- site_mask ,
1186- site_class ,
1187- cohort_size ,
1188- min_cohort_size ,
1189- max_cohort_size ,
1190- random_seed ,
1191- inline_array ,
1192- chunks ,
1193- ):
1186+ regions : base_params . regions_tuple ,
1187+ sample_sets : Optional [ base_params . sample_sets_tuple ] ,
1188+ sample_indices : Optional [ base_params . sample_indices_tuple ] ,
1189+ site_mask : Optional [ base_params . site_mask ] ,
1190+ site_class : Optional [ base_params . site_class ] ,
1191+ cohort_size : Optional [ base_params . cohort_size ] ,
1192+ min_cohort_size : Optional [ base_params . min_cohort_size ] ,
1193+ max_cohort_size : Optional [ base_params . max_cohort_size ] ,
1194+ random_seed : base_params . random_seed ,
1195+ inline_array : base_params . inline_array ,
1196+ chunks : base_params . chunks ,
1197+ ) -> xr . Dataset :
11941198 ## Get SNP calls and concatenate multiple sample sets and/or regions.
11951199
11961200 # Note: sample_sets should be "prepared" before being passed to this private function.
@@ -1305,23 +1309,25 @@ def _results_cache_add_analysis_params(self, params: dict):
13051309 def _snp_allele_counts (
13061310 self ,
13071311 * ,
1308- region ,
1309- sample_sets ,
1310- sample_indices ,
1311- site_mask ,
1312- site_class ,
1313- cohort_size ,
1314- min_cohort_size ,
1315- max_cohort_size ,
1316- random_seed ,
1317- inline_array ,
1318- chunks ,
1319- ):
1312+ region : Union [ dict , List [ dict ]] ,
1313+ sample_sets : Optional [ base_params . sample_sets_tuple ] ,
1314+ sample_indices : Optional [ base_params . sample_indices_tuple ] ,
1315+ site_mask : Optional [ base_params . site_mask ] ,
1316+ site_class : Optional [ base_params . site_class ] ,
1317+ cohort_size : Optional [ base_params . cohort_size ] ,
1318+ min_cohort_size : Optional [ base_params . min_cohort_size ] ,
1319+ max_cohort_size : Optional [ base_params . max_cohort_size ] ,
1320+ random_seed : base_params . random_seed ,
1321+ inline_array : base_params . inline_array ,
1322+ chunks : base_params . chunks ,
1323+ ) -> Dict [ str , np . ndarray ] :
13201324 # Access SNP calls.
1325+ # N.B., snp_calls is a public API with @_check_types, which expects
1326+ # List[int] for sample_indices, not a tuple. Convert back here.
13211327 ds_snps = self .snp_calls (
13221328 region = region ,
1323- sample_sets = sample_sets ,
1324- sample_indices = sample_indices ,
1329+ sample_sets = list ( sample_sets ) if sample_sets is not None else None ,
1330+ sample_indices = list ( sample_indices ) if sample_indices is not None else None ,
13251331 site_mask = site_mask ,
13261332 site_class = site_class ,
13271333 cohort_size = cohort_size ,
@@ -1402,6 +1408,15 @@ def snp_allele_counts(
14021408 sample_query_options = sample_query_options ,
14031409 sample_indices = sample_indices ,
14041410 )
1411+ # Convert lists to tuples to avoid CacheMiss "TypeError: unhashable type: 'list'".
1412+ sample_sets_prepped_tuple : Optional [base_params .sample_sets_tuple ] = (
1413+ tuple (sample_sets_prepped ) if sample_sets_prepped is not None else None
1414+ )
1415+ sample_indices_prepped_tuple : Optional [base_params .sample_indices_tuple ] = (
1416+ tuple (sample_indices_prepped )
1417+ if sample_indices_prepped is not None
1418+ else None
1419+ )
14051420 del sample_sets
14061421 del sample_query
14071422 del sample_query_options
@@ -1412,8 +1427,8 @@ def snp_allele_counts(
14121427 del site_mask
14131428 params = dict (
14141429 region = region_prepped ,
1415- sample_sets = sample_sets_prepped ,
1416- sample_indices = sample_indices_prepped ,
1430+ sample_sets = sample_sets_prepped_tuple ,
1431+ sample_indices = sample_indices_prepped_tuple ,
14171432 site_mask = site_mask_prepped ,
14181433 site_class = site_class ,
14191434 cohort_size = cohort_size ,
@@ -1427,7 +1442,17 @@ def snp_allele_counts(
14271442
14281443 except CacheMiss :
14291444 results = self ._snp_allele_counts (
1430- ** params , inline_array = inline_array , chunks = chunks
1445+ inline_array = inline_array ,
1446+ chunks = chunks ,
1447+ region = region_prepped ,
1448+ sample_sets = sample_sets_prepped_tuple ,
1449+ sample_indices = sample_indices_prepped_tuple ,
1450+ site_mask = site_mask_prepped ,
1451+ site_class = site_class ,
1452+ cohort_size = cohort_size ,
1453+ min_cohort_size = min_cohort_size ,
1454+ max_cohort_size = max_cohort_size ,
1455+ random_seed = random_seed ,
14311456 )
14321457 self .results_cache_set (name = name , params = params , results = results )
14331458
@@ -1964,6 +1989,16 @@ def biallelic_diplotypes(
19641989 prepared_region = self ._prep_region_cache_param (region = region )
19651990 prepared_site_mask = self ._prep_optional_site_mask_param (site_mask = site_mask )
19661991
1992+ # Convert lists to tuples to avoid CacheMiss "TypeError: unhashable type: 'list'".
1993+ prepared_sample_sets_tuple : Optional [base_params .sample_sets_tuple ] = (
1994+ tuple (prepared_sample_sets ) if prepared_sample_sets is not None else None
1995+ )
1996+ prepared_sample_indices_tuple : Optional [base_params .sample_indices_tuple ] = (
1997+ tuple (prepared_sample_indices )
1998+ if prepared_sample_indices is not None
1999+ else None
2000+ )
2001+
19672002 # Delete original parameters to prevent accidental use.
19682003 del sample_sets
19692004 del sample_query
@@ -1976,8 +2011,8 @@ def biallelic_diplotypes(
19762011 region = prepared_region ,
19772012 n_snps = n_snps ,
19782013 thin_offset = thin_offset ,
1979- sample_sets = prepared_sample_sets ,
1980- sample_indices = prepared_sample_indices ,
2014+ sample_sets = prepared_sample_sets_tuple ,
2015+ sample_indices = prepared_sample_indices_tuple ,
19812016 site_mask = prepared_site_mask ,
19822017 site_class = site_class ,
19832018 cohort_size = cohort_size ,
@@ -1994,7 +2029,21 @@ def biallelic_diplotypes(
19942029
19952030 except CacheMiss :
19962031 results = self ._biallelic_diplotypes (
1997- inline_array = inline_array , chunks = chunks , ** params
2032+ inline_array = inline_array ,
2033+ chunks = chunks ,
2034+ region = prepared_region ,
2035+ sample_sets = prepared_sample_sets_tuple ,
2036+ sample_indices = prepared_sample_indices_tuple ,
2037+ site_mask = prepared_site_mask ,
2038+ site_class = site_class ,
2039+ cohort_size = cohort_size ,
2040+ min_cohort_size = min_cohort_size ,
2041+ max_cohort_size = max_cohort_size ,
2042+ random_seed = random_seed ,
2043+ n_snps = n_snps ,
2044+ thin_offset = thin_offset ,
2045+ min_minor_ac = min_minor_ac ,
2046+ max_missing_an = max_missing_an ,
19982047 )
19992048 self .results_cache_set (name = name , params = params , results = results )
20002049
@@ -2007,29 +2056,31 @@ def biallelic_diplotypes(
20072056 def _biallelic_diplotypes (
20082057 self ,
20092058 * ,
2010- region ,
2011- sample_sets ,
2012- sample_indices ,
2013- site_mask ,
2014- site_class ,
2015- cohort_size ,
2016- min_cohort_size ,
2017- max_cohort_size ,
2018- random_seed ,
2019- max_missing_an ,
2020- min_minor_ac ,
2021- n_snps ,
2022- thin_offset ,
2023- inline_array ,
2024- chunks ,
2025- ):
2059+ region : Union [ dict , List [ dict ]] ,
2060+ sample_sets : Optional [ base_params . sample_sets_tuple ] ,
2061+ sample_indices : Optional [ base_params . sample_indices_tuple ] ,
2062+ site_mask : Optional [ base_params . site_mask ] ,
2063+ site_class : Optional [ base_params . site_class ] ,
2064+ cohort_size : Optional [ base_params . cohort_size ] ,
2065+ min_cohort_size : Optional [ base_params . min_cohort_size ] ,
2066+ max_cohort_size : Optional [ base_params . max_cohort_size ] ,
2067+ random_seed : base_params . random_seed ,
2068+ max_missing_an : Optional [ base_params . max_missing_an ] ,
2069+ min_minor_ac : Optional [ base_params . min_minor_ac ] ,
2070+ n_snps : Optional [ base_params . n_snps ] ,
2071+ thin_offset : base_params . thin_offset ,
2072+ inline_array : base_params . inline_array ,
2073+ chunks : base_params . chunks ,
2074+ ) -> Dict [ str , np . ndarray ] :
20262075 # Note: this function uses sample_indices and should not expect a sample_query.
20272076
20282077 # Access biallelic SNPs.
2078+ # N.B., biallelic_snp_calls is a public API with @_check_types, which
2079+ # expects List[int] for sample_indices, not a tuple. Convert back here.
20292080 ds = self .biallelic_snp_calls (
20302081 region = region ,
2031- sample_sets = sample_sets ,
2032- sample_indices = sample_indices ,
2082+ sample_sets = list ( sample_sets ) if sample_sets is not None else None ,
2083+ sample_indices = list ( sample_indices ) if sample_indices is not None else None ,
20332084 site_mask = site_mask ,
20342085 site_class = site_class ,
20352086 cohort_size = cohort_size ,
0 commit comments