Skip to content

Commit 15737a2

Browse files
committed
refactor: improve type checking in DataFrame expression handling
1 parent 6570061 commit 15737a2

2 files changed

Lines changed: 31 additions & 6 deletions

File tree

python/datafusion/dataframe.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -478,14 +478,16 @@ def _simplify_expression(
478478
) -> list[expr_internal.Expr]:
479479
expr_list: list[expr_internal.Expr] = []
480480
for expr in exprs:
481-
if isinstance(expr, str) or (
482-
isinstance(expr, Iterable)
483-
and not isinstance(expr, Expr)
484-
and any(isinstance(inner, str) for inner in expr)
485-
):
481+
if isinstance(expr, str):
486482
raise TypeError(EXPR_TYPE_ERROR)
483+
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
484+
expr_value = list(expr)
485+
if any(isinstance(inner, str) for inner in expr_value):
486+
raise TypeError(EXPR_TYPE_ERROR)
487+
else:
488+
expr_value = expr
487489
try:
488-
expr_list.extend(expr_list_to_raw_expr_list(expr))
490+
expr_list.extend(expr_list_to_raw_expr_list(expr_value))
489491
except TypeError as err:
490492
raise TypeError(EXPR_TYPE_ERROR) from err
491493
for alias, expr in named_exprs.items():
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pyarrow as pa
2+
from datafusion import SessionContext, column
3+
4+
5+
def test_with_columns_generator():
6+
ctx = SessionContext()
7+
batch = pa.RecordBatch.from_arrays(
8+
[pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
9+
names=["a", "b", "c"],
10+
)
11+
df = ctx.from_arrow(batch)
12+
13+
def gen():
14+
for name in ["d", "e"]:
15+
yield (column("a") + column("b")).alias(name)
16+
17+
df = df.with_columns(gen())
18+
result = df.collect()[0]
19+
20+
assert result.schema.names == ["a", "b", "c", "d", "e"]
21+
expected = pa.array([5, 7, 9])
22+
assert result.column(3) == expected
23+
assert result.column(4) == expected

0 commit comments

Comments
 (0)