|
| 1 | +import asyncio |
1 | 2 | import logging |
2 | 3 |
|
3 | 4 | import pytest |
@@ -247,6 +248,101 @@ def variable_hook(var_name): |
247 | 248 | ("height > 1 and height < 2", True, ["height"]), |
248 | 249 | ], |
249 | 250 | ) |
| 251 | +def async_variable_hook(var_name): |
| 252 | + """Variable hook that returns async callables, for testing issue #535.""" |
| 253 | + values = { |
| 254 | + "cond_true": True, |
| 255 | + "cond_false": False, |
| 256 | + "val_10": 10, |
| 257 | + "val_20": 20, |
| 258 | + } |
| 259 | + |
| 260 | + value = values.get(var_name, False) |
| 261 | + |
| 262 | + async def decorated(*args, **kwargs): |
| 263 | + await asyncio.sleep(0) |
| 264 | + return value |
| 265 | + |
| 266 | + decorated.__name__ = var_name |
| 267 | + return decorated |
| 268 | + |
| 269 | + |
| 270 | +@pytest.mark.parametrize( |
| 271 | + ("expression", "expected"), |
| 272 | + [ |
| 273 | + ("not cond_false", True), |
| 274 | + ("not cond_true", False), |
| 275 | + ("cond_true and cond_true", True), |
| 276 | + ("cond_true and cond_false", False), |
| 277 | + ("cond_false and cond_true", False), |
| 278 | + ("cond_false or cond_true", True), |
| 279 | + ("cond_true or cond_false", True), |
| 280 | + ("cond_false or cond_false", False), |
| 281 | + ("not cond_false and cond_true", True), |
| 282 | + ("not (cond_true and cond_false)", True), |
| 283 | + ("not (cond_false or cond_false)", True), |
| 284 | + ("cond_true and not cond_false", True), |
| 285 | + ("val_10 == 10", True), |
| 286 | + ("val_10 != 20", True), |
| 287 | + ("val_10 < val_20", True), |
| 288 | + ("val_20 > val_10", True), |
| 289 | + ("val_10 >= 10", True), |
| 290 | + ("val_10 <= val_20", True), |
| 291 | + ], |
| 292 | +) |
| 293 | +def test_async_expressions(expression, expected): |
| 294 | + """Issue #535: condition expressions with async predicates must await results.""" |
| 295 | + parsed_expr = parse_boolean_expr(expression, async_variable_hook, operator_mapping) |
| 296 | + result = parsed_expr() |
| 297 | + assert asyncio.iscoroutine(result), f"Expected coroutine for async expression: {expression}" |
| 298 | + assert asyncio.run(result) is expected, expression |
| 299 | + |
| 300 | + |
| 301 | +def mixed_variable_hook(var_name): |
| 302 | + """Variable hook where some vars are sync and some are async.""" |
| 303 | + sync_values = {"sync_true": True, "sync_false": False, "sync_10": 10} |
| 304 | + async_values = {"async_true": True, "async_false": False, "async_20": 20} |
| 305 | + |
| 306 | + if var_name in async_values: |
| 307 | + value = async_values[var_name] |
| 308 | + |
| 309 | + async def async_decorated(*args, **kwargs): |
| 310 | + await asyncio.sleep(0) |
| 311 | + return value |
| 312 | + |
| 313 | + async_decorated.__name__ = var_name |
| 314 | + return async_decorated |
| 315 | + |
| 316 | + def sync_decorated(*args, **kwargs): |
| 317 | + return sync_values.get(var_name, False) |
| 318 | + |
| 319 | + sync_decorated.__name__ = var_name |
| 320 | + return sync_decorated |
| 321 | + |
| 322 | + |
| 323 | +@pytest.mark.parametrize( |
| 324 | + ("expression", "expected"), |
| 325 | + [ |
| 326 | + # async left, sync right |
| 327 | + ("async_true and sync_true", True), |
| 328 | + ("async_false or sync_true", True), |
| 329 | + # sync left, async right |
| 330 | + ("sync_true and async_true", True), |
| 331 | + ("sync_false or async_true", True), |
| 332 | + ("sync_true and async_false", False), |
| 333 | + ("sync_false or async_false", False), |
| 334 | + ], |
| 335 | +) |
| 336 | +def test_mixed_sync_async_expressions(expression, expected): |
| 337 | + """Expressions mixing sync and async predicates must handle both correctly.""" |
| 338 | + parsed_expr = parse_boolean_expr(expression, mixed_variable_hook, operator_mapping) |
| 339 | + result = parsed_expr() |
| 340 | + if asyncio.iscoroutine(result): |
| 341 | + assert asyncio.run(result) is expected, expression |
| 342 | + else: |
| 343 | + assert result is expected, expression |
| 344 | + |
| 345 | + |
250 | 346 | @pytest.mark.xfail(reason="TODO: Optimize so that expressios are evaluated only once") |
251 | 347 | def test_should_evaluate_values_only_once(expression, expected, caplog, hooks_called): |
252 | 348 | caplog.set_level(logging.DEBUG, logger="tests") |
|
0 commit comments