Skip to content

Commit d1ace55

Browse files
committed
fix: migrate _check_types from typeguard to pydantic TypeAdapter (#484)
1 parent 05337a5 commit d1ace55

3 files changed

Lines changed: 214 additions & 22 deletions

File tree

malariagen_data/util.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import numpy as np
2727
import pandas as pd
2828
import plotly.express as px # type: ignore
29-
import typeguard
3029
import xarray as xr
3130
import zarr # type: ignore
3231

@@ -1156,28 +1155,38 @@ def _check_colab_location(gcp_region: Optional[str]):
11561155

11571156

11581157
def _check_types(f):
1159-
"""Simple decorator to provide runtime checking of parameter types.
1158+
"""Decorator to provide runtime checking of parameter types.
11601159
1161-
N.B., the typeguard package does have a decorator function called
1162-
@typechecked which performs a similar purpose. However, the typeguard
1163-
decorator causes a memory leak and doesn't seem usable. Also, the
1164-
typeguard decorator performs runtime checking of all variables within
1165-
the function as well as the arguments and return values. We only want
1166-
checking of the arguments to help users provide correct inputs.
1160+
Uses Pydantic v2's TypeAdapter for parameter validation.
1161+
Validates input types without coercing arguments — the original
1162+
function is called with the original unmodified arguments.
11671163
11681164
"""
1165+
from pydantic import ConfigDict, TypeAdapter, ValidationError
1166+
1167+
type_hints = get_type_hints(f)
1168+
config = ConfigDict(arbitrary_types_allowed=True)
1169+
1170+
# Build a TypeAdapter for each annotated parameter (skip 'return').
1171+
adapters: dict = {}
1172+
for k, t in type_hints.items():
1173+
if k == "return":
1174+
continue
1175+
try:
1176+
adapters[k] = TypeAdapter(t, config=config)
1177+
except Exception:
1178+
pass # Skip types pydantic cannot handle
11691179

11701180
@wraps(f)
1171-
def check_types_wrapper(*args, **kwargs):
1172-
type_hints = get_type_hints(f)
1181+
def wrapper(*args, **kwargs):
11731182
call_args = getcallargs(f, *args, **kwargs)
1174-
for k, t in type_hints.items():
1183+
for k, adapter in adapters.items():
11751184
if k in call_args:
11761185
v = call_args[k]
11771186
try:
1178-
typeguard.check_type(v, t)
1179-
except typeguard.TypeCheckError as e:
1180-
expected_type = humanize_type(t)
1187+
adapter.validate_python(v, strict=True)
1188+
except ValidationError as e:
1189+
expected_type = humanize_type(type_hints[k])
11811190
actual_type = humanize_type(type(v))
11821191
message = fill(
11831192
dedent(
@@ -1188,11 +1197,10 @@ def check_types_wrapper(*args, **kwargs):
11881197
)
11891198
)
11901199
message += f"\n\n{e}"
1191-
error = TypeError(message)
1192-
raise error from None
1200+
raise TypeError(message) from None
11931201
return f(*args, **kwargs)
11941202

1195-
return check_types_wrapper
1203+
return wrapper
11961204

11971205

11981206
@numba.njit

0 commit comments

Comments
 (0)