Skip to content

Commit 43198f3

Browse files
authored
Merge pull request #1040 from 31puneet/fix/issue-484-pydantic-validate-call
fix: migrate _check_types from typeguard to pydantic validate_call
2 parents b698fa9 + a9d3e0f commit 43198f3

File tree

4 files changed

+237
-36
lines changed

4 files changed

+237
-36
lines changed

malariagen_data/anoph/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ def read_files(
234234
paths: Iterable[str],
235235
on_error: Literal["raise", "omit", "return"] = "return",
236236
) -> Mapping[str, Union[bytes, Exception]]:
237+
# Pydantic validate_call with strict=True converts Iterable into a
238+
# generator, which can be exhausted. Convert to a tuple first.
239+
paths = tuple(paths)
240+
237241
# Check for any cached files.
238242
files = {
239243
path: data for path, data in self._cache_files.items() if path in paths

malariagen_data/util.py

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import numpy as np
2727
import pandas as pd
2828
import plotly.express as px # type: ignore
29-
import typeguard
3029
import xarray as xr
3130
import zarr # type: ignore
3231

@@ -1207,43 +1206,57 @@ def _check_colab_location(gcp_region: Optional[str]):
12071206

12081207

12091208
def _check_types(f):
1210-
"""Simple decorator to provide runtime checking of parameter types.
1209+
"""Decorator to provide runtime checking of parameter types.
12111210
1212-
N.B., the typeguard package does have a decorator function called
1213-
@typechecked which performs a similar purpose. However, the typeguard
1214-
decorator causes a memory leak and doesn't seem usable. Also, the
1215-
typeguard decorator performs runtime checking of all variables within
1216-
the function as well as the arguments and return values. We only want
1217-
checking of the arguments to help users provide correct inputs.
1211+
Uses Pydantic v2's validate_call() for parameter validation.
1212+
Validates input types in strict mode without coercing arguments.
12181213
12191214
"""
1215+
from pydantic import ConfigDict, ValidationError, validate_call
1216+
1217+
config = ConfigDict(strict=True, arbitrary_types_allowed=True)
1218+
1219+
try:
1220+
validated_f = validate_call(config=config, validate_return=False)(f)
1221+
except Exception as exc:
1222+
warnings.warn(
1223+
f"Could not apply validate_call to {f.__name__!r}: {exc}. "
1224+
"Type validation will be skipped for this function.",
1225+
stacklevel=2,
1226+
)
1227+
return f
12201228

12211229
@wraps(f)
1222-
def check_types_wrapper(*args, **kwargs):
1223-
type_hints = get_type_hints(f)
1224-
call_args = getcallargs(f, *args, **kwargs)
1225-
for k, t in type_hints.items():
1226-
if k in call_args:
1227-
v = call_args[k]
1228-
try:
1229-
typeguard.check_type(v, t)
1230-
except typeguard.TypeCheckError as e:
1231-
expected_type = humanize_type(t)
1232-
actual_type = humanize_type(type(v))
1233-
message = fill(
1234-
dedent(
1235-
f"""
1236-
Parameter {k!r} with value {v!r} in call to function {f.__name__!r} has incorrect type:
1237-
found {actual_type}, expected {expected_type}. See below for further information.
1238-
"""
1230+
def wrapper(*args, **kwargs):
1231+
try:
1232+
return validated_f(*args, **kwargs)
1233+
except ValidationError as e:
1234+
type_hints = get_type_hints(f)
1235+
call_args = getcallargs(f, *args, **kwargs)
1236+
errors = e.errors()
1237+
if errors:
1238+
err = errors[0]
1239+
loc = err.get("loc", ())
1240+
if loc:
1241+
k = str(loc[0])
1242+
v = call_args.get(k)
1243+
t = type_hints.get(k)
1244+
if t is not None:
1245+
expected_type = humanize_type(t)
1246+
actual_type = humanize_type(type(v))
1247+
message = fill(
1248+
dedent(
1249+
f"""\
1250+
Parameter {k!r} with value {v!r} in call to function {f.__name__!r} has incorrect type:
1251+
found {actual_type}, expected {expected_type}. See below for further information.
1252+
"""
1253+
)
12391254
)
1240-
)
1241-
message += f"\n\n{e}"
1242-
error = TypeError(message)
1243-
raise error from None
1244-
return f(*args, **kwargs)
1255+
message += f"\n\n{e}"
1256+
raise TypeError(message) from None
1257+
raise TypeError(str(e)) from None
12451258

1246-
return check_types_wrapper
1259+
return wrapper
12471260

12481261

12491262
@numba.njit

0 commit comments

Comments
 (0)