@@ -439,7 +439,6 @@ def _cohort_count_het_vectorized(
439439
440440 # Extract sample IDs from cohort dataframe
441441 sample_ids = df_cohort_samples ["sample_id" ].values
442- sample_id_to_idx = {sid : idx for idx , sid in enumerate (sample_ids )}
443442
444443 debug ("access SNPs for all cohort samples" )
445444 # Load SNP data once for all samples in cohort
@@ -451,6 +450,10 @@ def _cohort_count_het_vectorized(
451450 inline_array = inline_array ,
452451 )
453452
453+ # Subset to cohort samples to ensure correct indexing
454+ ds_snps = ds_snps .set_index (samples = "sample_id" ).sel (samples = sample_ids )
455+ sample_id_to_idx = {sid : idx for idx , sid in enumerate (sample_ids )}
456+
454457 # SNP positions (same for all samples)
455458 pos = ds_snps ["variant_position" ].values
456459
@@ -470,18 +473,17 @@ def _cohort_count_het_vectorized(
470473 )
471474
472475 # access genotypes for all samples
473- gt = allel .GenotypeDaskArray (ds_snps ["call_genotype" ].data )
474-
475- # compute het across all samples: shape (variants, samples)
476- debug ("Compute heterozygous genotypes for all samples" )
477- with self ._dask_progress (desc = "Compute heterozygous genotypes" ):
478- is_het_all = gt .is_het ().compute ()
476+ gt_data = ds_snps ["call_genotype" ].data
479477
480478 # Compute windowed heterozygosity for each sample and cache results
481479 results = {}
482480 for sample_id , sample_idx in sample_id_to_idx .items ():
483- # Extract heterozygosity column for this sample
484- is_het_sample = is_het_all [:, sample_idx ]
481+ # Compute heterozygous genotypes for this sample only to avoid
482+ # materializing the full (variants, samples) array in memory.
483+ debug (f"Compute heterozygous genotypes for sample { sample_id } " )
484+ gt_sample = allel .GenotypeDaskVector (gt_data [:, sample_idx , :])
485+ with self ._dask_progress (desc = "Compute heterozygous genotypes" ):
486+ is_het_sample = gt_sample .is_het ().compute ()
485487
486488 # compute windowed heterozygosity for this sample
487489 counts = allel .moving_statistic (
@@ -910,7 +912,7 @@ def cohort_heterozygosity(
910912 # Compute per-sample means and aggregate.
911913 het_values = []
912914 for sample_id in df_cohort_samples ["sample_id" ]:
913- windows , counts = cohort_het_results [sample_id ]
915+ _ , counts = cohort_het_results [sample_id ]
914916 het_mean = np .mean (counts / window_size )
915917 het_values .append (het_mean )
916918
0 commit comments