@@ -395,6 +395,105 @@ def _sample_count_het(
395395
396396 return sample_id , sample_set , windows , counts
397397
398+ def _cohort_count_het_vectorized (
399+ self ,
400+ region : Region ,
401+ df_cohort_samples : pd .DataFrame ,
402+ sample_sets : Optional [base_params .sample_sets ],
403+ window_size : het_params .window_size ,
404+ site_mask : Optional [base_params .site_mask ],
405+ chunks : base_params .chunks ,
406+ inline_array : base_params .inline_array ,
407+ ):
408+ """Vectorized computation of windowed heterozygosity for multiple samples.
409+
410+ Loads SNP data once for all cohort samples, then computes heterozygosity
411+ across all samples efficiently, rather than calling snp_calls() repeatedly
412+ for each sample.
413+
414+ Parameters
415+ ----------
416+ region : Region
417+ Genome region to analyze.
418+ df_cohort_samples : pd.DataFrame
419+ Sample metadata dataframe with at least 'sample_id' column.
420+ sample_sets : str, optional
421+ Sample set identifier(s).
422+ window_size : int
423+ Size of sliding windows for heterozygosity computation.
424+ site_mask : str, optional
425+ Site mask to apply.
426+ chunks : str or int, dict
427+ Chunk size for dask arrays.
428+ inline_array : bool
429+ Whether to inline arrays.
430+
431+ Returns
432+ -------
433+ dict
434+ Mapping from sample_id to (windows, counts) tuple, where:
435+ - windows: array of shape (n_windows, 2) with [start, stop] positions
436+ - counts: array of shape (n_windows,) with heterozygous site counts per window
437+ """
438+ debug = self ._log .debug
439+
440+ # Extract sample IDs from cohort dataframe
441+ sample_ids = df_cohort_samples ["sample_id" ].values
442+ sample_id_to_idx = {sid : idx for idx , sid in enumerate (sample_ids )}
443+
444+ debug ("access SNPs for all cohort samples" )
445+ # Load SNP data once for all samples in cohort
446+ ds_snps = self .snp_calls (
447+ region = region ,
448+ sample_sets = sample_sets ,
449+ site_mask = site_mask ,
450+ chunks = chunks ,
451+ inline_array = inline_array ,
452+ )
453+
454+ # SNP positions (same for all samples)
455+ pos = ds_snps ["variant_position" ].values
456+
457+ # guard against window_size exceeding available sites
458+ if pos .shape [0 ] < window_size :
459+ raise ValueError (
460+ f"Not enough sites ({ pos .shape [0 ]} ) for window size "
461+ f"({ window_size } ). Please reduce the window size or "
462+ f"use different site selection criteria."
463+ )
464+
465+ # Compute window coordinates once (same for all samples)
466+ windows = allel .moving_statistic (
467+ values = pos ,
468+ statistic = lambda x : [x [0 ], x [- 1 ]],
469+ size = window_size ,
470+ )
471+
472+ # access genotypes for all samples
473+ gt = allel .GenotypeDaskArray (ds_snps ["call_genotype" ].data )
474+
475+ # compute het across all samples: shape (variants, samples)
476+ debug ("Compute heterozygous genotypes for all samples" )
477+ with self ._dask_progress (desc = "Compute heterozygous genotypes" ):
478+ is_het_all = gt .is_het ().compute ()
479+
480+ # Compute windowed heterozygosity for each sample and cache results
481+ results = {}
482+ for sample_id , sample_idx in sample_id_to_idx .items ():
483+ # Extract heterozygosity column for this sample
484+ is_het_sample = is_het_all [:, sample_idx ]
485+
486+ # compute windowed heterozygosity for this sample
487+ counts = allel .moving_statistic (
488+ values = is_het_sample ,
489+ statistic = np .sum ,
490+ size = window_size ,
491+ )
492+
493+ results [sample_id ] = (windows , counts )
494+
495+ return results
496+
398497 @property
399498 def _roh_hmm_cache_name (self ):
400499 return "roh_hmm_v1"
@@ -795,18 +894,25 @@ def cohort_heterozygosity(
795894 )
796895 n_samples = len (df_cohort_samples )
797896
798- # Compute heterozygosity for each sample and take the mean.
897+ # Compute heterozygosity for all samples in the cohort using vectorized method.
898+ # This loads SNP data once and computes heterozygosity across all samples,
899+ # yielding substantial speedup over sequential per-sample processing.
900+ cohort_het_results = self ._cohort_count_het_vectorized (
901+ region = region_prepped ,
902+ df_cohort_samples = df_cohort_samples ,
903+ sample_sets = sample_sets ,
904+ window_size = window_size ,
905+ site_mask = site_mask ,
906+ chunks = chunks ,
907+ inline_array = inline_array ,
908+ )
909+
910+ # Compute per-sample means and aggregate.
799911 het_values = []
800912 for sample_id in df_cohort_samples ["sample_id" ]:
801- df_het = self .sample_count_het (
802- sample = sample_id ,
803- region = region_prepped ,
804- window_size = window_size ,
805- site_mask = site_mask ,
806- chunks = chunks ,
807- inline_array = inline_array ,
808- )
809- het_values .append (df_het ["heterozygosity" ].mean ())
913+ windows , counts = cohort_het_results [sample_id ]
914+ het_mean = np .mean (counts / window_size )
915+ het_values .append (het_mean )
810916
811917 results .append (
812918 {
0 commit comments