Skip to content

Commit 3db159a

Browse files
committed
feat: update Expr.literal to handle None and convert values to arrays
1 parent 0ef75b1 commit 3db159a

3 files changed

Lines changed: 33 additions & 11 deletions

File tree

python/datafusion/expr.py

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -444,11 +444,14 @@ def literal(value: Any) -> Expr:
444 444
445 445           ``value`` must be a valid PyArrow scalar value or easily castable to one.
446 446           """
447     -         if isinstance(value, str):
448     -             value = pa.scalar(value, type=pa.string_view())
449     -         if not isinstance(value, pa.Scalar):
450     -             value = pa.scalar(value)
451     -         return Expr(expr_internal.RawExpr.literal(value))
    447 +         if value is None:
    448 +             array = pa.array([None])
    449 +         elif isinstance(value, str):
    450 +             array = pa.array([value], type=pa.string_view())
    451 +         else:
    452 +             array = pa.array([value])
    453 +
    454 +         return Expr(expr_internal.RawExpr.literal(array))
452 455
453 456       @staticmethod
454 457       def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
@@ -458,11 +461,14 @@ def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
458 461               value: A valid PyArrow scalar value or easily castable to one.
459 462               metadata: Metadata to attach to the expression.
460 463           """
461     -         if isinstance(value, str):
462     -             value = pa.scalar(value, type=pa.string_view())
463     -         value = value if isinstance(value, pa.Scalar) else pa.scalar(value)
    464 +         if value is None:
    465 +             array = pa.array([None])
    466 +         elif isinstance(value, str):
    467 +             array = pa.array([value], type=pa.string_view())
    468 +         else:
    469 +             array = pa.array([value])
464 470
465     -         return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))
    471 +         return Expr(expr_internal.RawExpr.literal_with_metadata(array, metadata))
466 472
467 473       @staticmethod
468 474       def string_literal(value: str) -> Expr:
@@ -476,8 +482,8 @@ def string_literal(value: str) -> Expr:
476 482           https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
477 483           """
478 484           if isinstance(value, str):
479     -             value = pa.scalar(value, type=pa.string())
480     -             return Expr(expr_internal.RawExpr.literal(value))
    485 +             array = pa.array([value], type=pa.string())
    486 +             return Expr(expr_internal.RawExpr.literal(array))
481 487           return Expr.literal(value)
482 488
483 489       @staticmethod

python/tests/test_dataframe.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -302,6 +302,13 @@ def test_limit_with_offset(df):
302 302       assert len(result.column(1)) == 1
303 303
304 304
    305 + def test_literal_numeric_and_string(df):
    306 +     df = df.select(literal(10), literal("foo"))
    307 +     result = df.collect()[0]
    308 +     assert result.column(0) == pa.array([10, 10, 10])
    309 +     assert result.column(1) == pa.array(["foo", "foo", "foo"], type=pa.string_view())
    310 +
    311 +
305 312   def test_head(df):
306 313       df = df.head(1)
307 314

python/tests/test_functions.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
 22  22   import pytest
 23  23   from datafusion import SessionContext, column, literal, string_literal
 24  24   from datafusion import functions as f
     25 + from datafusion.expr import Expr
 25  26
 26  27   np.seterr(invalid="ignore")
 27  28
@@ -110,6 +111,14 @@ def test_lit_arith(df):
110 111       )
111 112
112 113
    114 + def test_expr_literal_numeric_and_string():
    115 +     """Ensure Expr.literal handles numeric and string values."""
    116 +     num_expr = Expr.literal(7)
    117 +     str_expr = Expr.literal("bar")
    118 +     assert num_expr.to_variant() == 7
    119 +     assert str_expr.to_variant() == "bar"
    120 +
    121 +
113 122   def test_math_functions():
114 123       ctx = SessionContext()
115 124       # create a RecordBatch and a new DataFrame from it

0 commit comments

Comments
 (0)