@@ -371,6 +371,10 @@ def cohort_diversity_stats(
371371 ) -> pd .Series :
372372 debug = self ._log .debug
373373
374+ # Change this name if you ever change the behaviour of this function, to
375+ # invalidate any previously cached data.
376+ name = "cohort_diversity_stats_v1"
377+
374378 debug ("process cohort parameter" )
375379 cohort_query = None
376380 if isinstance (cohort , str ):
@@ -391,28 +395,59 @@ def cohort_diversity_stats(
391395 else :
392396 raise TypeError (f"invalid cohort parameter: { cohort !r} " )
393397
394- debug ("access allele counts" )
395- ac = self .snp_allele_counts (
398+ params = dict (
399+ cohort_label = cohort_label ,
400+ cohort_query = cohort_query ,
401+ cohort_size = cohort_size ,
396402 region = region ,
403+ min_cohort_size = min_cohort_size ,
404+ max_cohort_size = max_cohort_size ,
397405 site_mask = site_mask ,
398406 site_class = site_class ,
399- sample_query = cohort_query ,
400407 sample_sets = sample_sets ,
401- cohort_size = cohort_size ,
402- min_cohort_size = min_cohort_size ,
403- max_cohort_size = max_cohort_size ,
404408 random_seed = random_seed ,
409+ n_jack = n_jack ,
410+ confidence_level = confidence_level ,
405411 chunks = chunks ,
406412 inline_array = inline_array ,
407413 )
408414
409- debug ("compute diversity stats" )
410- stats = self ._block_jackknife_cohort_diversity_stats (
411- cohort_label = cohort_label ,
412- ac = ac ,
413- n_jack = n_jack ,
414- confidence_level = confidence_level ,
415- )
415+ # Try to retrieve results from the cache.
416+ try :
417+ results = self .results_cache_get (name = name , params = params )
418+ stats = {
419+ key : value .item ()
420+ if isinstance (value , np .ndarray ) and value .shape == ()
421+ else value
422+ for key , value in results .items ()
423+ }
424+
425+ except CacheMiss :
426+ debug ("access allele counts" )
427+ ac = self .snp_allele_counts (
428+ region = region ,
429+ site_mask = site_mask ,
430+ site_class = site_class ,
431+ sample_query = cohort_query ,
432+ sample_sets = sample_sets ,
433+ cohort_size = cohort_size ,
434+ min_cohort_size = min_cohort_size ,
435+ max_cohort_size = max_cohort_size ,
436+ random_seed = random_seed ,
437+ chunks = chunks ,
438+ inline_array = inline_array ,
439+ )
440+
441+ debug ("compute diversity stats" )
442+ stats = self ._block_jackknife_cohort_diversity_stats (
443+ cohort_label = cohort_label ,
444+ ac = ac ,
445+ n_jack = n_jack ,
446+ confidence_level = confidence_level ,
447+ )
448+
449+ cache_results = {key : np .asarray (value ) for key , value in stats .items ()}
450+ self .results_cache_set (name = name , params = params , results = cache_results )
416451
417452 debug ("compute some extra cohort variables" )
418453 df_samples = self .sample_metadata (
0 commit comments