|
26 | 26 | import numpy as np |
27 | 27 | import pandas as pd |
28 | 28 | import plotly.express as px # type: ignore |
29 | | -import typeguard |
30 | 29 | import xarray as xr |
31 | 30 | import zarr # type: ignore |
32 | 31 |
|
@@ -1207,43 +1206,57 @@ def _check_colab_location(gcp_region: Optional[str]): |
1207 | 1206 |
|
1208 | 1207 |
|
1209 | 1208 | def _check_types(f): |
1210 | | - """Simple decorator to provide runtime checking of parameter types. |
| 1209 | + """Decorator to provide runtime checking of parameter types. |
1211 | 1210 |
|
1212 | | - N.B., the typeguard package does have a decorator function called |
1213 | | - @typechecked which performs a similar purpose. However, the typeguard |
1214 | | - decorator causes a memory leak and doesn't seem usable. Also, the |
1215 | | - typeguard decorator performs runtime checking of all variables within |
1216 | | - the function as well as the arguments and return values. We only want |
1217 | | - checking of the arguments to help users provide correct inputs. |
| 1211 | + Uses Pydantic v2's validate_call() for parameter validation. |
| 1212 | + Validates input types in strict mode without coercing arguments. |
1218 | 1213 |
|
1219 | 1214 | """ |
| 1215 | + from pydantic import ConfigDict, ValidationError, validate_call |
| 1216 | + |
| 1217 | + config = ConfigDict(strict=True, arbitrary_types_allowed=True) |
| 1218 | + |
| 1219 | + try: |
| 1220 | + validated_f = validate_call(config=config, validate_return=False)(f) |
| 1221 | + except Exception as exc: |
| 1222 | + warnings.warn( |
| 1223 | + f"Could not apply validate_call to {f.__name__!r}: {exc}. " |
| 1224 | + "Type validation will be skipped for this function.", |
| 1225 | + stacklevel=2, |
| 1226 | + ) |
| 1227 | + return f |
1220 | 1228 |
|
1221 | 1229 | @wraps(f) |
1222 | | - def check_types_wrapper(*args, **kwargs): |
1223 | | - type_hints = get_type_hints(f) |
1224 | | - call_args = getcallargs(f, *args, **kwargs) |
1225 | | - for k, t in type_hints.items(): |
1226 | | - if k in call_args: |
1227 | | - v = call_args[k] |
1228 | | - try: |
1229 | | - typeguard.check_type(v, t) |
1230 | | - except typeguard.TypeCheckError as e: |
1231 | | - expected_type = humanize_type(t) |
1232 | | - actual_type = humanize_type(type(v)) |
1233 | | - message = fill( |
1234 | | - dedent( |
1235 | | - f""" |
1236 | | - Parameter {k!r} with value {v!r} in call to function {f.__name__!r} has incorrect type: |
1237 | | - found {actual_type}, expected {expected_type}. See below for further information. |
1238 | | - """ |
| 1230 | + def wrapper(*args, **kwargs): |
| 1231 | + try: |
| 1232 | + return validated_f(*args, **kwargs) |
| 1233 | + except ValidationError as e: |
| 1234 | + type_hints = get_type_hints(f) |
| 1235 | + call_args = getcallargs(f, *args, **kwargs) |
| 1236 | + errors = e.errors() |
| 1237 | + if errors: |
| 1238 | + err = errors[0] |
| 1239 | + loc = err.get("loc", ()) |
| 1240 | + if loc: |
| 1241 | + k = str(loc[0]) |
| 1242 | + v = call_args.get(k) |
| 1243 | + t = type_hints.get(k) |
| 1244 | + if t is not None: |
| 1245 | + expected_type = humanize_type(t) |
| 1246 | + actual_type = humanize_type(type(v)) |
| 1247 | + message = fill( |
| 1248 | + dedent( |
| 1249 | + f"""\ |
| 1250 | + Parameter {k!r} with value {v!r} in call to function {f.__name__!r} has incorrect type: |
| 1251 | + found {actual_type}, expected {expected_type}. See below for further information. |
| 1252 | + """ |
| 1253 | + ) |
1239 | 1254 | ) |
1240 | | - ) |
1241 | | - message += f"\n\n{e}" |
1242 | | - error = TypeError(message) |
1243 | | - raise error from None |
1244 | | - return f(*args, **kwargs) |
| 1255 | + message += f"\n\n{e}" |
| 1256 | + raise TypeError(message) from None |
| 1257 | + raise TypeError(str(e)) from None |
1245 | 1258 |
|
1246 | | - return check_types_wrapper |
| 1259 | + return wrapper |
1247 | 1260 |
|
1248 | 1261 |
|
1249 | 1262 | @numba.njit |
|
0 commit comments