Skip to content

Commit caa5046

Browse files
committed
Support Cast node in pushdown logic
1 parent 261a68a commit caa5046

2 files changed

Lines changed: 44 additions & 0 deletions

File tree

duckdb/polars_io.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,13 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
159159
msg = f"Unsupported function type: {func_dict}"
160160
raise NotImplementedError(msg)
161161

162+
if node_type == "Cast":
163+
cast_tree = tree[node_type]
164+
assert isinstance(cast_tree, dict), f"A {node_type} should be a dict but got {type(cast_tree)}"
165+
cast_expr = cast_tree["expr"]
166+
assert isinstance(cast_expr, dict), f"A {node_type} should be a dict but got {type(cast_expr)}"
167+
return _pl_tree_to_sql(cast_expr)
168+
162169
if node_type == "Scalar":
163170
# Detect format: old style (dtype/value) or new style (direct type key)
164171
scalar_tree = tree[node_type]

tests/fast/arrow/test_polars.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,3 +702,40 @@ def test_decimal_scale(self):
702702
} }
703703
"""
704704
assert _pl_tree_to_sql(json.loads(scalar_decimal_scale)) == "1"
705+
706+
def test_cast_node_unwraps_inner_expression(self):
707+
"""Cast nodes should be unwrapped to process the inner expression."""
708+
# A Cast wrapping a Column reference
709+
cast_column = json.loads(
710+
'{"Cast": {"expr": {"Column": "a"}, "dtype": {"Decimal": [20, 0]}, "options": "NonStrict"}}'
711+
)
712+
assert _pl_tree_to_sql(cast_column) == '"a"'
713+
714+
# A Cast wrapping a full binary expression
715+
cast_expr = json.loads("""
716+
{
717+
"BinaryExpr": {
718+
"left": {"Cast": {"expr": {"Column": "a"}, "dtype": {"Decimal": [20, 0]}, "options": "NonStrict"}},
719+
"op": "Eq",
720+
"right": {"Literal": {"Int": 1}}
721+
}
722+
}
723+
""")
724+
assert _pl_tree_to_sql(cast_expr) == '("a" = 1)'
725+
726+
def test_cast_node_predicate_pushdown(self):
727+
"""Predicates with Cast nodes should be successfully pushed down."""
728+
# A decimal with non-38 precision produces a Cast node in Polars
729+
expr = pl.col("a") == pl.lit(1, dtype=pl.Decimal(precision=20, scale=0))
730+
valid_filter(expr)
731+
732+
def test_polars_lazy_pushdown_decimal_with_cast(self):
733+
"""End-to-end test: decimal columns with non-38 precision should push down filters."""
734+
con = duckdb.connect()
735+
con.execute("CREATE TABLE test_cast (a DECIMAL(20,0))")
736+
con.execute("INSERT INTO test_cast VALUES (1), (10), (100), (NULL)")
737+
rel = con.sql("FROM test_cast")
738+
lazy_df = rel.pl(lazy=True)
739+
740+
assert lazy_df.filter(pl.col("a") == 1).collect().to_dicts() == [{"a": 1}]
741+
assert lazy_df.filter(pl.col("a") > 1).collect().to_dicts() == [{"a": 10}, {"a": 100}]

0 commit comments

Comments
 (0)