Skip to content

Commit 55013e1

Browse files
Allow strict casts of literals only in polars lazyframe pushdown (#348)
2 parents c3cf406 + b8d19d0 commit 55013e1

2 files changed

Lines changed: 14 additions & 4 deletions

File tree

duckdb/polars_io.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -170,8 +170,17 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
170170
if node_type == "Cast":
171171
cast_tree = tree[node_type]
172172
assert isinstance(cast_tree, dict), f"A {node_type} should be a dict but got {type(cast_tree)}"
173-
if cast_tree.get("options") not in ("NonStrict", "Strict"):
174-
msg = f"Only NonStrict/Strict casts can be safely unwrapped, got {cast_tree.get('options')!r}"
173+
options = cast_tree.get("options")
174+
if options == "Strict":
175+
# Strict casts on literals (e.g. pl.lit(1, dtype=pl.Int8)) are safe to unwrap —
176+
# the value is known at expression creation time. Strict casts on columns
177+
# (e.g. pl.col("a").cast(pl.Int64)) are semantically meaningful and must not be dropped.
178+
cast_expr = cast_tree.get("expr", {})
179+
if not isinstance(cast_expr, dict) or "Literal" not in cast_expr:
180+
msg = "Strict cast on non-literal expression cannot be pushed down"
181+
raise NotImplementedError(msg)
182+
elif options != "NonStrict":
183+
msg = f"Only NonStrict/Strict casts can be safely unwrapped, got {options!r}"
175184
raise NotImplementedError(msg)
176185
cast_expr = cast_tree["expr"]
177186
assert isinstance(cast_expr, dict), f"A {node_type} should be a dict but got {type(cast_expr)}"

tests/fast/arrow/test_polars.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212

1313
from duckdb.polars_io import _pl_tree_to_sql, _predicate_to_expression # noqa: E402
1414

15+
pl_pre_1_35_0 = parse_version(pl.__version__) < parse_version("1.35.0")
1516
pl_pre_1_36_0 = parse_version(pl.__version__) < parse_version("1.36.0")
1617

1718

@@ -437,7 +438,7 @@ def test_polars_lazy_pushdown_timestamp(self, duckdb_cursor):
437438
lazy_df.filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008)).select(pl.len()).collect().item() == 2
438439
)
439440

440-
@pytest.mark.skipif(pl_pre_1_36_0, reason="Polars < 1.36.0 expressions on dates produce casts in predicates")
441+
@pytest.mark.skipif(pl_pre_1_35_0, reason="Polars < 1.36.0 expressions on dates produce casts in predicates")
441442
def test_polars_predicate_to_expression_post_1_36_0(self):
442443
ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1)
443444
ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1)
@@ -454,7 +455,7 @@ def test_polars_predicate_to_expression_post_1_36_0(self):
454455
valid_filter((pl.col("a") == ts_2020) & (pl.col("b") == ts_2010) & (pl.col("c") == ts_2020))
455456
valid_filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008))
456457

457-
@pytest.mark.skipif(not pl_pre_1_36_0, reason="Polars >= 1.36.0 expressions on dates don't produce casts")
458+
@pytest.mark.skipif(not pl_pre_1_35_0, reason="Polars >= 1.36.0 expressions on dates don't produce casts")
458459
def test_polars_predicate_to_expression_pre_1_36_0(self):
459460
ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1)
460461
ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1)

0 commit comments

Comments (0)