Skip to content

Commit a5a917c

Browse files
authored
Merge branch 'master' into fix/veff-genome-cache-memory-leak
2 parents e229ac1 + 8ef069c commit a5a917c

8 files changed

Lines changed: 373 additions & 20 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
.idea
22
.vscode
33
__pycache__
4+
.mypy_cache
45
*.pyc
56
dist
67
.venv/

malariagen_data/anoph/base_params.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,14 @@ def _validate_sample_selection_params(
189189
"Random seed used for reproducible down-sampling.",
190190
]
191191

192+
gene: TypeAlias = Annotated[
193+
str,
194+
"""
195+
Gene identifier. Can be either a gene ID or gene name.
196+
Gene names are matched case-insensitively.
197+
""",
198+
]
199+
192200
transcript: TypeAlias = Annotated[
193201
str,
194202
"Gene transcript identifier.",

malariagen_data/anoph/cnv_frq.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,11 @@ def _gene_cnv(
9090
inline_array,
9191
):
9292
# Sanity check.
93-
assert isinstance(region, Region)
93+
if not isinstance(region, Region):
94+
raise TypeError(
95+
f"Expected region to be a Region object, "
96+
f"got {type(region).__name__}: {region!r}"
97+
)
9498

9599
# Access genes within the region of interest.
96100
df_genome_features = self.genome_features(region=region)
@@ -260,7 +264,11 @@ def _gene_cnv_frequencies(
260264
debug = self._log.debug
261265

262266
debug("sanity check - this function is one region at a time")
263-
assert isinstance(region, Region)
267+
if not isinstance(region, Region):
268+
raise TypeError(
269+
f"Expected region to be a Region object, "
270+
f"got {type(region).__name__}: {region!r}"
271+
)
264272

265273
debug("get gene copy number data")
266274
ds_cnv = self.gene_cnv(
@@ -504,7 +512,11 @@ def _gene_cnv_frequencies_advanced(
504512
debug = self._log.debug
505513

506514
debug("sanity check - here we deal with one region only")
507-
assert isinstance(region, Region)
515+
if not isinstance(region, Region):
516+
raise TypeError(
517+
f"Expected region to be a Region object, "
518+
f"got {type(region).__name__}: {region!r}"
519+
)
508520

509521
debug("access gene CNV calls")
510522
ds_cnv = self.gene_cnv(

malariagen_data/anoph/frq_base.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,11 @@ def plot_frequencies_heatmap(
341341
for j in range(1, idx_vals.shape[1]):
342342
index_col = index_col + ", " + idx_vals[:, j]
343343
else:
344-
assert isinstance(index, str)
344+
if not isinstance(index, str):
345+
raise TypeError(
346+
f"Expected index to be str or list, "
347+
f"got {type(index).__name__}: {index!r}"
348+
)
345349
index_col = df[index].astype(str)
346350

347351
# Check that index is unique.
@@ -413,6 +417,14 @@ def plot_frequencies_heatmap(
413417
`aa_allele_frequencies_advanced()` or
414418
`gene_cnv_frequencies_advanced()`.
415419
""",
420+
taxa="""
421+
Taxon or list of taxa to include in the plot. If None,
422+
all taxa are shown.
423+
""",
424+
areas="""
425+
Area or list of areas to include in the plot. If None,
426+
all areas are shown.
427+
""",
416428
kwargs="Passed through to `px.line()`.",
417429
),
418430
returns="""
@@ -588,7 +600,11 @@ def plot_frequencies_map_markers(
588600
ds_variant = ds.isel(variants=variant)
589601
variant_label = ds["variant_label"].values[variant]
590602
else:
591-
assert isinstance(variant, str)
603+
if not isinstance(variant, str):
604+
raise TypeError(
605+
f"Expected variant to be int or str, "
606+
f"got {type(variant).__name__}: {variant!r}"
607+
)
592608
ds_variant = ds.set_index(variants="variant_label").sel(variants=variant)
593609
variant_label = variant
594610

malariagen_data/anoph/genome_features.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,140 @@ def plot_transcript(
314314
bokeh.plotting.show(fig)
315315
return fig
316316

317+
@_check_types
318+
@doc(
319+
summary="Get the canonical transcript for a gene.",
320+
returns="""
321+
The transcript ID for the canonical transcript of the specified gene.
322+
The canonical transcript is the one with the highest number of
323+
transcribed base pairs (sum of exon lengths).
324+
""",
325+
)
326+
def canonical_transcript(
327+
self,
328+
gene: base_params.gene,
329+
) -> str:
330+
"""
331+
Parameters
332+
----------
333+
gene : str
334+
A gene identifier. Can be either a gene ID or gene name.
335+
336+
Returns
337+
-------
338+
str
339+
The transcript ID of the canonical transcript.
340+
341+
Raises
342+
------
343+
ValueError
344+
If the gene identifier is not found or if the gene has no transcripts.
345+
346+
Examples
347+
--------
348+
Get the canonical transcript for a gene by ID:
349+
350+
>>> import malariagen_data
351+
>>> ag3 = malariagen_data.ag3(pre=False)
352+
>>> canonical = ag3.canonical_transcript("AGAP004707")
353+
354+
Get the canonical transcript for a gene by name:
355+
356+
>>> canonical = ag3.canonical_transcript("Pvr")
357+
"""
358+
debug = self._log.debug
359+
debug(f"Looking up canonical transcript for gene '{gene}'")
360+
361+
# Load genome features once with required attributes
362+
with self._spinner(desc="Load gene data"):
363+
# Load required attributes (ordered for consistency with GFF3)
364+
attributes = ("ID", "Parent", self._gff_gene_name_attribute)
365+
df_features = self.genome_features(attributes=attributes)
366+
debug(f"Loaded {len(df_features)} genome features")
367+
368+
# Filter for genes
369+
df_genes = df_features[df_features["type"] == self._gff_gene_type]
370+
name_attr = self._gff_gene_name_attribute
371+
372+
# Normalize input: strip whitespace
373+
gene_normalized = gene.strip()
374+
375+
# Reject empty identifiers after normalization to avoid ambiguous matches
376+
if not gene_normalized:
377+
raise ValueError(
378+
"Gene identifier is empty after stripping whitespace; please provide a valid gene ID or name."
379+
)
380+
# Try exact ID match first (case-sensitive)
381+
debug(f"Attempting ID match for '{gene_normalized}'")
382+
gene_id_match = df_genes[df_genes["ID"].str.strip() == gene_normalized]
383+
384+
if len(gene_id_match) == 1:
385+
gene_id = gene_id_match.iloc[0]["ID"]
386+
debug(f"Found ID match: {gene_id}")
387+
elif len(gene_id_match) > 1:
388+
# This should not happen (ID should be unique), but handling gracefully
389+
raise ValueError(
390+
f"Multiple features with ID '{gene}' found (data integrity issue)"
391+
)
392+
else:
393+
# Trying name match (case-insensitive with whitespace handling)
394+
debug("No ID match, attempting name match")
395+
gene_name_match = df_genes[
396+
df_genes[name_attr].fillna("").str.strip().str.lower()
397+
== gene_normalized.lower()
398+
]
399+
400+
if len(gene_name_match) == 0:
401+
raise ValueError(f"Gene '{gene}' not found (no matching ID or name)")
402+
elif len(gene_name_match) > 1:
403+
# Suggest which genes matched for better debugging
404+
matching_ids = ", ".join(gene_name_match["ID"].values)
405+
raise ValueError(
406+
f"Gene name '{gene}' is ambiguous (matches {len(gene_name_match)} genes: {matching_ids}). "
407+
f"Please use a specific gene ID instead."
408+
)
409+
410+
gene_id = gene_name_match.iloc[0]["ID"]
411+
debug(f"Found name match: {gene_id}")
412+
413+
# Get transcripts for the gene
414+
debug(f"Finding transcripts for gene '{gene_id}'")
415+
df_transcripts = self.genome_feature_children(
416+
parent=gene_id, attributes=("ID",)
417+
)
418+
df_transcripts = df_transcripts[df_transcripts["type"] == "mRNA"]
419+
420+
if len(df_transcripts) == 0:
421+
raise ValueError(f"Gene '{gene}' has no transcripts")
422+
423+
debug(f"Found {len(df_transcripts)} transcripts for gene {gene_id}")
424+
425+
# Calculate transcript lengths and find canonical
426+
debug("Calculating transcript lengths for each transcript")
427+
transcript_lengths = {}
428+
429+
for transcript_id in df_transcripts["ID"]:
430+
# Get all exon children (genome_feature_children handles multi-parent exons)
431+
df_exons = self.genome_feature_children(
432+
parent=transcript_id, attributes=None
433+
)
434+
# Filter for exons only (important: exclude other feature types)
435+
df_exons = df_exons[df_exons["type"] == "exon"].sort_values("start")
436+
437+
if len(df_exons) > 0:
438+
# Calculate total transcribed length (1-based inclusive coordinates)
439+
exon_lengths = (df_exons["end"] - df_exons["start"] + 1).sum()
440+
transcript_lengths[transcript_id] = exon_lengths
441+
debug(f" {transcript_id}: {len(df_exons)} exons, {exon_lengths} bp")
442+
if not transcript_lengths:
443+
raise ValueError(f"Gene '{gene}' has no transcripts with exons")
444+
445+
# Find canonical (maximum length)
446+
canonical = max(transcript_lengths, key=lambda tid: transcript_lengths[tid])
447+
canonical_length = transcript_lengths[canonical]
448+
debug(f"Canonical transcript: {canonical} with {canonical_length} bp")
449+
return canonical
450+
317451
@_check_types
318452
@doc(
319453
summary="Plot a genes track, using bokeh.",

malariagen_data/anoph/sample_metadata.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,11 @@ def lookup_sample(
12351235
if isinstance(sample, str):
12361236
sample_rec = df_samples.loc[sample]
12371237
else:
1238-
assert isinstance(sample, int)
1238+
if not isinstance(sample, int):
1239+
raise TypeError(
1240+
f"Expected sample to be str or int, "
1241+
f"got {type(sample).__name__}: {sample!r}"
1242+
)
12391243
sample_rec = df_samples.iloc[sample]
12401244
return sample_rec
12411245

@@ -1348,7 +1352,11 @@ def _setup_sample_symbol(
13481352

13491353
else:
13501354
# Custom grouping using queries.
1351-
assert isinstance(symbol, Mapping)
1355+
if not isinstance(symbol, Mapping):
1356+
raise TypeError(
1357+
f"Expected symbol to be str or Mapping, "
1358+
f"got {type(symbol).__name__}: {symbol!r}"
1359+
)
13521360
data["symbol"] = ""
13531361
for key, value in symbol.items():
13541362
data.loc[data.query(value).index, "symbol"] = key
@@ -1397,7 +1405,11 @@ def _setup_sample_colors_plotly(
13971405

13981406
else:
13991407
# Custom grouping using queries.
1400-
assert isinstance(color, Mapping)
1408+
if not isinstance(color, Mapping):
1409+
raise TypeError(
1410+
f"Expected color to be str or Mapping, "
1411+
f"got {type(color).__name__}: {color!r}"
1412+
)
14011413
data["color"] = ""
14021414
for key, value in color.items():
14031415
data.loc[data.query(value).index, "color"] = key
@@ -1493,13 +1505,17 @@ def _setup_cohort_queries(
14931505
"""Convenience function to normalise the `cohorts` parameter to a
14941506
dictionary mapping cohort labels to sample metadata queries."""
14951507

1496-
if isinstance(cohorts, dict):
1508+
if isinstance(cohorts, Mapping):
14971509
# User has supplied a custom dictionary mapping cohort identifiers
14981510
# to pandas queries.
14991511
cohort_queries = cohorts
15001512

15011513
else:
1502-
assert isinstance(cohorts, str)
1514+
if not isinstance(cohorts, str):
1515+
raise TypeError(
1516+
f"Expected cohorts to be Mapping or str, "
1517+
f"got {type(cohorts).__name__}: {cohorts!r}"
1518+
)
15031519
# User has supplied a column in the sample metadata.
15041520
df_samples = self.sample_metadata(
15051521
sample_sets=sample_sets,
@@ -1855,7 +1871,11 @@ def _locate_cohorts(*, cohorts, data, min_cohort_size):
18551871
coh_dict[coh] = loc_coh
18561872

18571873
else:
1858-
assert isinstance(cohorts, str)
1874+
if not isinstance(cohorts, str):
1875+
raise TypeError(
1876+
f"Expected cohorts to be Mapping or str, "
1877+
f"got {type(cohorts).__name__}: {cohorts!r}"
1878+
)
18591879
# User has supplied the name of a sample metadata column.
18601880

18611881
# Convenience to allow things like "admin1_year" instead of "cohort_admin1_year".

0 commit comments

Comments
 (0)