Skip to content

Commit bf6b47c

Browse files
Support Cast node in pushdown logic (#298)
Related: #98 When the decimal precision is anything other than 38, Polars wraps the expression in a Cast node. `_pl_tree_to_sql` didn't handle Cast nodes, causing silent pushdown failure. This fix makes sure these nodes get pushed down as well. ``` import duckdb import polars as pl import json # Create a decimal column with precision != 38 con = duckdb.connect() rel = con.sql("SELECT a::DECIMAL(20,0) AS a FROM range(10) AS t(a)") lazy_df = rel.pl(lazy=True) # This filter works but pushdown silently fails result = lazy_df.filter(pl.col("a") == 1).collect() print(f"Result: {len(result)} rows") # Returns 1 row (correct) # Polars serialized it as follows expr = pl.col("a") == 1 tree = json.loads(expr.meta.serialize(format="json")) print(json.dumps(tree, indent=2)) # Contains: "Cast": {"expr": ..., "dtype": {"Decimal": [20, 0]}, ...} ``` Note that we have to also check for the strictness of the cast node: - NonStrict casts are pushable, they are auto-inserted type coercion casts (like DECIMAL(20,0) to DECIMAL(38,0)) - Strict casts are fallible (and can error, see pola-rs/polars#22669)
2 parents e21cf62 + e93d591 commit bf6b47c

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

duckdb/polars_io.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
159159
msg = f"Unsupported function type: {func_dict}"
160160
raise NotImplementedError(msg)
161161

162+
if node_type == "Cast":
163+
cast_tree = tree[node_type]
164+
assert isinstance(cast_tree, dict), f"A {node_type} should be a dict but got {type(cast_tree)}"
165+
if cast_tree.get("options") != "NonStrict":
166+
msg = f"Only NonStrict casts can be safely unwrapped, got {cast_tree.get('options')!r}"
167+
raise NotImplementedError(msg)
168+
cast_expr = cast_tree["expr"]
169+
assert isinstance(cast_expr, dict), f"A {node_type} should be a dict but got {type(cast_expr)}"
170+
return _pl_tree_to_sql(cast_expr)
171+
162172
if node_type == "Scalar":
163173
# Detect format: old style (dtype/value) or new style (direct type key)
164174
scalar_tree = tree[node_type]

tests/fast/arrow/test_polars.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,3 +702,46 @@ def test_decimal_scale(self):
702702
} }
703703
"""
704704
assert _pl_tree_to_sql(json.loads(scalar_decimal_scale)) == "1"
705+
706+
def test_cast_node_unwraps_inner_expression(self):
707+
"""Cast nodes should be unwrapped to process the inner expression."""
708+
# A Cast wrapping a Column reference
709+
cast_column = json.loads(
710+
'{"Cast": {"expr": {"Column": "a"}, "dtype": {"Decimal": [20, 0]}, "options": "NonStrict"}}'
711+
)
712+
assert _pl_tree_to_sql(cast_column) == '"a"'
713+
714+
# A Cast wrapping a full binary expression
715+
cast_expr = json.loads("""
716+
{
717+
"BinaryExpr": {
718+
"left": {"Cast": {"expr": {"Column": "a"}, "dtype": {"Decimal": [20, 0]}, "options": "NonStrict"}},
719+
"op": "Eq",
720+
"right": {"Literal": {"Int": 1}}
721+
}
722+
}
723+
""")
724+
assert _pl_tree_to_sql(cast_expr) == '("a" = 1)'
725+
726+
def test_cast_node_predicate_pushdown(self):
727+
"""Predicates with Cast nodes should be successfully pushed down."""
728+
# A decimal with non-38 precision produces a Cast node in Polars
729+
expr = pl.col("a") == pl.lit(1, dtype=pl.Decimal(precision=20, scale=0))
730+
valid_filter(expr)
731+
732+
def test_polars_lazy_pushdown_decimal_with_cast(self):
733+
"""End-to-end test: decimal columns with non-38 precision should push down filters."""
734+
con = duckdb.connect()
735+
con.execute("CREATE TABLE test_cast (a DECIMAL(20,0))")
736+
con.execute("INSERT INTO test_cast VALUES (1), (10), (100), (NULL)")
737+
rel = con.sql("FROM test_cast")
738+
lazy_df = rel.pl(lazy=True)
739+
740+
assert lazy_df.filter(pl.col("a") == 1).collect().to_dicts() == [{"a": 1}]
741+
assert lazy_df.filter(pl.col("a") > 1).collect().to_dicts() == [{"a": 10}, {"a": 100}]
742+
743+
def test_explicit_cast_not_pushed_down(self):
744+
"""Explicit user .cast() (Strict) should not be pushed down - falls back to Polars."""
745+
# pl.col("a").cast(pl.Int64) produces a Strict Cast node
746+
expr = pl.col("a").cast(pl.Int64) > 5
747+
invalid_filter(expr)

0 commit comments

Comments
 (0)