Skip to content

Commit 398980d

Browse files
authored
Support None comparisons for null expressions (#1489)
* Support None comparisons for null expressions * Fold None comparison coverage into relational expr test
1 parent 8a7efea commit 398980d

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

python/datafusion/expr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,8 @@ def __eq__(self, rhs: object) -> Expr:
483483
484484
Accepts either an expression or any valid PyArrow scalar literal value.
485485
"""
486+
if rhs is None:
487+
return self.is_null()
486488
if not isinstance(rhs, Expr):
487489
rhs = Expr.literal(rhs)
488490
return Expr(self.expr.__eq__(rhs.expr))
@@ -492,6 +494,8 @@ def __ne__(self, rhs: object) -> Expr:
492494
493495
Accepts either an expression or any valid PyArrow scalar literal value.
494496
"""
497+
if rhs is None:
498+
return self.is_not_null()
495499
if not isinstance(rhs, Expr):
496500
rhs = Expr.literal(rhs)
497501
return Expr(self.expr.__ne__(rhs.expr))

python/tests/test_expr.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def test_relational_expr(test_ctx):
153153

154154
batch = pa.RecordBatch.from_arrays(
155155
[
156-
pa.array([1, 2, 3]),
157-
pa.array(["alpha", "beta", "gamma"], type=pa.string_view()),
156+
pa.array([1, 2, 3, None]),
157+
pa.array(["alpha", "beta", "gamma", None], type=pa.string_view()),
158158
],
159159
names=["a", "b"],
160160
)
@@ -171,6 +171,10 @@ def test_relational_expr(test_ctx):
171171
assert df.filter(col("b") != "beta").count() == 2
172172

173173
assert df.filter(col("a") == "beta").count() == 0
174+
assert df.filter(col("a") == None).count() == 1 # noqa: E711
175+
assert df.filter(col("a") != None).count() == 3 # noqa: E711
176+
assert df.filter(col("b") == None).count() == 1 # noqa: E711
177+
assert df.filter(col("b") != None).count() == 3 # noqa: E711
174178

175179

176180
def test_expr_to_variant():

0 commit comments

Comments
 (0)