Skip to content

Commit 07c2741

Browse files
committed
Allow plain Python literals in regexp function wrappers
1 parent 3585c11 commit 07c2741

3 files changed

Lines changed: 85 additions & 27 deletions

File tree

python/datafusion/expr.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,23 @@ def _iter(
276276

277277
return list(_iter(exprs))
278278

279+
def _to_raw_literal_expr(value: Expr | Any) -> expr_internal.Expr:
280+
"""Convert an expression or Python literal to its raw variant.
281+
282+
Args:
283+
value: Candidate expression or Python literal value.
284+
285+
Returns:
286+
The internal :class:`~datafusion._internal.expr.Expr` representation.
287+
288+
Examples:
289+
>>> expr = Expr(_to_raw_literal_expr(1))
290+
>>> isinstance(expr, Expr)
291+
True
292+
"""
293+
if isinstance(value, Expr):
294+
return value.expr
295+
return Expr.literal(value).expr
279296

280297
def _to_raw_expr(value: Expr | str) -> expr_internal.Expr:
281298
"""Convert a Python expression or column name to its raw variant.

python/datafusion/functions.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
expr_list_to_raw_expr_list,
3333
sort_list_to_raw_sort_list,
3434
sort_or_default,
35+
_to_raw_literal_expr,
3536
)
3637

3738
__all__ = [
@@ -1440,7 +1441,7 @@ def radians(arg: Expr) -> Expr:
14401441
return Expr(f.radians(arg.expr))
14411442

14421443

1443-
def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr:
1444+
def regexp_like(string: Expr, regex: Expr | Any, flags: Expr | Any | None = None) -> Expr:
14441445
r"""Find if any regular expression (regex) matches exist.
14451446
14461447
Tests a string using a regular expression returning true if at least one match,
@@ -1468,12 +1469,14 @@ def regexp_like(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr:
14681469
>>> result.collect_column("m")[0].as_py()
14691470
True
14701471
"""
1471-
if flags is not None:
1472-
flags = flags.expr
1473-
return Expr(f.regexp_like(string.expr, regex.expr, flags))
1472+
# if flags is not None:
1473+
# flags = flags.expr
1474+
# return Expr(f.regexp_like(string.expr, regex.expr, flags))
1475+
flags = _to_raw_literal_expr(flags) if flags is not None else None
1476+
return Expr(f.regexp_like(string.expr, _to_raw_literal_expr(regex), flags))
14741477

14751478

1476-
def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr:
1479+
def regexp_match(string: Expr, regex: Expr | Any, flags: Expr | Any | None = None) -> Expr:
14771480
r"""Perform regular expression (regex) matching.
14781481
14791482
Returns an array with each element containing the leftmost-first match of the
@@ -1501,13 +1504,15 @@ def regexp_match(string: Expr, regex: Expr, flags: Expr | None = None) -> Expr:
15011504
>>> result.collect_column("m")[0].as_py()
15021505
['hello']
15031506
"""
1504-
if flags is not None:
1505-
flags = flags.expr
1506-
return Expr(f.regexp_match(string.expr, regex.expr, flags))
1507+
# if flags is not None:
1508+
# flags = flags.expr
1509+
# return Expr(f.regexp_match(string.expr, regex.expr, flags))
1510+
flags = _to_raw_literal_expr(flags) if flags is not None else None
1511+
return Expr(f.regexp_match(string.expr, _to_raw_literal_expr(regex), flags))
15071512

15081513

15091514
def regexp_replace(
1510-
string: Expr, pattern: Expr, replacement: Expr, flags: Expr | None = None
1515+
string: Expr, pattern: Expr | Any, replacement: Expr | Any, flags: Expr | Any | None = None
15111516
) -> Expr:
15121517
r"""Replaces substring(s) matching a PCRE-like regular expression.
15131518
@@ -1541,13 +1546,17 @@ def regexp_replace(
15411546
>>> result.collect_column("r")[0].as_py()
15421547
'aX bX cX'
15431548
"""
1544-
if flags is not None:
1545-
flags = flags.expr
1546-
return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags))
1549+
# if flags is not None:
1550+
# flags = flags.expr
1551+
# return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags))
1552+
flags = _to_raw_literal_expr(flags) if flags is not None else None
1553+
pattern = _to_raw_literal_expr(pattern)
1554+
replacement = _to_raw_literal_expr(replacement)
1555+
return Expr(f.regexp_replace(string.expr, pattern, replacement, flags))
15471556

15481557

15491558
def regexp_count(
1550-
string: Expr, pattern: Expr, start: Expr | None = None, flags: Expr | None = None
1559+
string: Expr, pattern: Expr | Any, start: Expr | Any | None = None, flags: Expr | Any | None = None
15511560
) -> Expr:
15521561
"""Returns the number of matches in a string.
15531562
@@ -1575,19 +1584,22 @@ def regexp_count(
15751584
>>> result.collect_column("c")[0].as_py()
15761585
1
15771586
"""
1578-
if flags is not None:
1579-
flags = flags.expr
1580-
start = start.expr if start is not None else start
1581-
return Expr(f.regexp_count(string.expr, pattern.expr, start, flags))
1587+
# if flags is not None:
1588+
# flags = flags.expr
1589+
# start = start.expr if start is not None else start
1590+
# return Expr(f.regexp_count(string.expr, pattern.expr, start, flags))
1591+
flags = _to_raw_literal_expr(flags) if flags is not None else None
1592+
start = _to_raw_literal_expr(start) if start is not None else None
1593+
return Expr(f.regexp_count(string.expr, _to_raw_literal_expr(pattern), start, flags))
15821594

15831595

15841596
def regexp_instr(
15851597
values: Expr,
1586-
regex: Expr,
1587-
start: Expr | None = None,
1588-
n: Expr | None = None,
1589-
flags: Expr | None = None,
1590-
sub_expr: Expr | None = None,
1598+
regex: Expr | Any,
1599+
start: Expr | Any | None = None,
1600+
n: Expr | Any | None = None,
1601+
flags: Expr | Any | None = None,
1602+
sub_expr: Expr | Any | None = None,
15911603
) -> Expr:
15921604
r"""Returns the position of a regular expression match in a string.
15931605
@@ -1635,15 +1647,20 @@ def regexp_instr(
16351647
>>> result.collect_column("pos")[0].as_py()
16361648
1
16371649
"""
1638-
start = start.expr if start is not None else None
1639-
n = n.expr if n is not None else None
1640-
flags = flags.expr if flags is not None else None
1641-
sub_expr = sub_expr.expr if sub_expr is not None else None
1650+
# start = start.expr if start is not None else None
1651+
# n = n.expr if n is not None else None
1652+
# flags = flags.expr if flags is not None else None
1653+
# sub_expr = sub_expr.expr if sub_expr is not None else None
1654+
regex = _to_raw_literal_expr(regex)
1655+
start = _to_raw_literal_expr(start) if start is not None else None
1656+
n = _to_raw_literal_expr(n) if n is not None else None
1657+
flags = _to_raw_literal_expr(flags) if flags is not None else None
1658+
sub_expr = _to_raw_literal_expr(sub_expr) if sub_expr is not None else None
16421659

16431660
return Expr(
16441661
f.regexp_instr(
16451662
values.expr,
1646-
regex.expr,
1663+
regex,
16471664
start,
16481665
n,
16491666
flags,

python/tests/test_functions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,30 @@ def test_map_functions(func, expected):
932932
f.regexp_count(column("a"), literal("(ell|orl)")),
933933
pa.array([1, 1, 0], type=pa.int64()),
934934
),
935+
(
936+
f.regexp_like(column("a"), "(ell|orl)"),
937+
pa.array([True, True, False]),
938+
),
939+
(
940+
f.regexp_match(column("a"), "(ell|orl)"),
941+
pa.array([["ell"], ["orl"], None], type=pa.list_(pa.string_view())),
942+
),
943+
(
944+
f.regexp_replace(column("a"), "(ell|orl)", "-"),
945+
pa.array(["H-o", "W-d", "!"], type=pa.string_view()),
946+
),
947+
(
948+
f.regexp_count(column("a"), "(ell|orl)", start=1),
949+
pa.array([1, 1, 0], type=pa.int64()),
950+
),
951+
(
952+
f.regexp_count(column("a"), "(ELL|ORL)", flags="i"),
953+
pa.array([1, 1, 0], type=pa.int64()),
954+
),
955+
(
956+
f.regexp_instr(column("a"), "([lr])", n=2),
957+
pa.array([4, 4, 0], type=pa.int64()),
958+
),
935959
(
936960
f.regexp_instr(column("a"), literal("(ell|orl)")),
937961
pa.array([2, 2, 0], type=pa.int64()),

0 commit comments

Comments
 (0)