@@ -395,6 +395,108 @@ def _sample_count_het(
395395
396396 return sample_id , sample_set , windows , counts
397397
def cohort_count_het(
    self,
    region: Region,
    df_cohort_samples: pd.DataFrame,
    sample_sets: Optional[base_params.sample_sets],
    window_size: het_params.window_size,
    site_mask: Optional[base_params.site_mask],
    chunks: base_params.chunks,
    inline_array: base_params.inline_array,
) -> dict:
    """Compute windowed heterozygous-site counts for each sample in a cohort.

    The SNP dataset is opened with a single ``snp_calls()`` call and the
    window coordinates are computed once, because they depend only on the
    site positions. Heterozygous calls are then computed one sample at a
    time, so the full (variants, samples) array is never materialised.

    Parameters
    ----------
    region : Region
        Genome region to analyse.
    df_cohort_samples : pd.DataFrame
        Sample metadata with at least a ``sample_id`` column. Repeated
        sample IDs are processed once.
    sample_sets : optional
        Sample set selection, passed through to ``snp_calls()``.
    window_size : int
        Number of sites per window.
    site_mask : str, optional
        Site mask, passed through to ``snp_calls()``.
    chunks
        Dask chunking option, passed through to ``snp_calls()``.
    inline_array : bool
        Dask inlining option, passed through to ``snp_calls()``.

    Returns
    -------
    dict
        Mapping from sample ID to a ``(windows, counts)`` tuple, where
        ``windows`` holds the first and last site position of each window
        and ``counts`` holds the number of heterozygous calls per window.
        The same ``windows`` array object is shared by every entry.

    Raises
    ------
    ValueError
        If the region has fewer sites than ``window_size``.
    """
    debug = self._log.debug

    # Results are keyed by sample ID, so a repeated ID can only ever yield
    # one entry; drop repeats up front (order of first appearance is kept)
    # instead of selecting the same genotype column more than once.
    sample_ids = df_cohort_samples["sample_id"].unique()

    debug("access SNPs for all cohort samples")
    ds_snps = self.snp_calls(
        region=region,
        sample_sets=sample_sets,
        site_mask=site_mask,
        chunks=chunks,
        inline_array=inline_array,
    )

    # Re-index by sample ID and select the cohort, so that position i along
    # the samples dimension corresponds to sample_ids[i].
    ds_snps = ds_snps.set_index(samples="sample_id").sel(samples=sample_ids)

    # Site positions are shared by all samples.
    pos = ds_snps["variant_position"].values

    # Guard against window_size exceeding available sites.
    if pos.shape[0] < window_size:
        raise ValueError(
            f"Not enough sites ({pos.shape[0]}) for window size "
            f"({window_size}). Please reduce the window size or "
            f"use different site selection criteria."
        )

    # Window coordinates depend only on positions, so compute them once.
    windows = allel.moving_statistic(
        values=pos,
        statistic=lambda x: [x[0], x[-1]],
        size=window_size,
    )

    # Lazy genotype array; indexed below as [variant, sample, allele].
    gt_data = ds_snps["call_genotype"].data

    # NOTE(review): every iteration runs its own dask compute. Whether the
    # underlying stored chunks are re-read each time depends on chunking
    # and caching -- worth profiling for large cohorts.
    results = {}
    for sample_idx, sample_id in enumerate(sample_ids):
        debug(f"Compute heterozygous genotypes for sample {sample_id}")
        gt_sample = allel.GenotypeDaskVector(gt_data[:, sample_idx, :])
        with self._dask_progress(desc="Compute heterozygous genotypes"):
            is_het_sample = gt_sample.is_het().compute()

        # Number of heterozygous calls in each window for this sample.
        counts = allel.moving_statistic(
            values=is_het_sample,
            statistic=np.sum,
            size=window_size,
        )

        results[sample_id] = (windows, counts)

    return results
499+
@property
def _roh_hmm_cache_name(self) -> str:
    """Return the fixed name string ``"roh_hmm_v1"``."""
    return "roh_hmm_v1"
@@ -816,18 +918,25 @@ def cohort_heterozygosity(
816918 )
817919 n_samples = len (df_cohort_samples )
818920
819- # Compute heterozygosity for each sample and take the mean.
921+ # Compute heterozygosity for all samples in the cohort using cohort_count_het().
922+ # This public method loads SNP data once and computes across all samples,
923+ # providing substantial speedup over sequential per-sample processing.
924+ cohort_het_results = self .cohort_count_het (
925+ region = region_prepped ,
926+ df_cohort_samples = df_cohort_samples ,
927+ sample_sets = sample_sets ,
928+ window_size = window_size ,
929+ site_mask = site_mask ,
930+ chunks = chunks ,
931+ inline_array = inline_array ,
932+ )
933+
934+ # Compute per-sample means and aggregate.
820935 het_values = []
821936 for sample_id in df_cohort_samples ["sample_id" ]:
822- df_het = self .sample_count_het (
823- sample = sample_id ,
824- region = region_prepped ,
825- window_size = window_size ,
826- site_mask = site_mask ,
827- chunks = chunks ,
828- inline_array = inline_array ,
829- )
830- het_values .append (df_het ["heterozygosity" ].mean ())
937+ _ , counts = cohort_het_results [sample_id ]
938+ het_mean = np .mean (counts / window_size )
939+ het_values .append (het_mean )
831940
832941 results .append (
833942 {
0 commit comments