Skip to content

Commit b8a4bac

Browse files
timsaucerclaude
andcommitted
Add map functions (make_map, map_keys, map_values, map_extract, map_entries, element_at)
Closes #1448 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent be8dd9d commit b8a4bac

2 files changed

Lines changed: 147 additions & 0 deletions

File tree

crates/core/src/functions.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ fn array_cat(exprs: Vec<PyExpr>) -> PyExpr {
9393
array_concat(exprs)
9494
}
9595

96+
#[pyfunction]
97+
fn make_map(keys: Vec<PyExpr>, values: Vec<PyExpr>) -> PyExpr {
98+
let keys = keys.into_iter().map(|x| x.into()).collect();
99+
let values = values.into_iter().map(|x| x.into()).collect();
100+
datafusion::functions_nested::map::map(keys, values).into()
101+
}
102+
103+
#[pyfunction]
104+
fn element_at(map: PyExpr, key: PyExpr) -> PyExpr {
105+
datafusion::functions_nested::expr_fn::map_extract(map.into(), key.into()).into()
106+
}
107+
96108
#[pyfunction]
97109
#[pyo3(signature = (array, element, index=None))]
98110
fn array_position(array: PyExpr, element: PyExpr, index: Option<i64>) -> PyExpr {
@@ -666,6 +678,12 @@ array_fn!(cardinality, array);
666678
array_fn!(flatten, array);
667679
array_fn!(range, start stop step);
668680

681+
// Map Functions
682+
array_fn!(map_keys, map);
683+
array_fn!(map_values, map);
684+
array_fn!(map_extract, map key);
685+
array_fn!(map_entries, map);
686+
669687
aggregate_function!(array_agg);
670688
aggregate_function!(max);
671689
aggregate_function!(min);
@@ -1126,6 +1144,14 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
11261144
m.add_wrapped(wrap_pyfunction!(flatten))?;
11271145
m.add_wrapped(wrap_pyfunction!(cardinality))?;
11281146

1147+
// Map Functions
1148+
m.add_wrapped(wrap_pyfunction!(make_map))?;
1149+
m.add_wrapped(wrap_pyfunction!(map_keys))?;
1150+
m.add_wrapped(wrap_pyfunction!(map_values))?;
1151+
m.add_wrapped(wrap_pyfunction!(map_extract))?;
1152+
m.add_wrapped(wrap_pyfunction!(map_entries))?;
1153+
m.add_wrapped(wrap_pyfunction!(element_at))?;
1154+
11291155
// Window Functions
11301156
m.add_wrapped(wrap_pyfunction!(lead))?;
11311157
m.add_wrapped(wrap_pyfunction!(lag))?;

python/datafusion/functions.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from __future__ import annotations
2020

21+
import builtins
2122
from typing import TYPE_CHECKING, Any
2223

2324
import pyarrow as pa
@@ -139,6 +140,7 @@
139140
"degrees",
140141
"dense_rank",
141142
"digest",
143+
"element_at",
142144
"empty",
143145
"encode",
144146
"ends_with",
@@ -202,7 +204,12 @@
202204
"make_array",
203205
"make_date",
204206
"make_list",
207+
"make_map",
205208
"make_time",
209+
"map_entries",
210+
"map_extract",
211+
"map_keys",
212+
"map_values",
206213
"max",
207214
"md5",
208215
"mean",
@@ -3374,6 +3381,120 @@ def empty(array: Expr) -> Expr:
33743381
return array_empty(array)
33753382

33763383

3384+
# map functions
3385+
3386+
3387+
def make_map(*args: Expr) -> Expr:
    """Builds a map expression from alternating key and value expressions.

    The arguments are read pairwise: ``make_map(k1, v1, k2, v2)`` produces
    the map ``{k1: v1, k2: v2}``.

    Args:
        *args: Key and value expressions, interleaved as
            ``key, value, key, value, ...``.

    Raises:
        ValueError: If an odd number of arguments is supplied.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.from_pydict({"a": [1]})
        >>> result = df.select(
        ...     dfn.functions.make_map(
        ...         dfn.lit("a"), dfn.lit(1),
        ...         dfn.lit("b"), dfn.lit(2),
        ...     ).alias("map"))
        >>> result.collect_column("map")[0].as_py()
        [('a', 1), ('b', 2)]
    """
    pair_count, leftover = divmod(len(args), 2)
    if leftover:
        msg = "make_map requires an even number of arguments"
        raise ValueError(msg)
    # ``builtins.range`` is spelled out explicitly, as in the rest of this
    # function's history, rather than relying on the bare name ``range``.
    keys = [args[2 * pair].expr for pair in builtins.range(pair_count)]
    values = [args[2 * pair + 1].expr for pair in builtins.range(pair_count)]
    return Expr(f.make_map(keys, values))
3410+
3411+
3412+
def map_keys(map: Expr) -> Expr:
    """Collects every key of a map expression into a list.

    Args:
        map: The map expression to read keys from.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.from_pydict({"a": [1]})
        >>> result = df.select(
        ...     dfn.functions.map_keys(
        ...         dfn.functions.make_map(
        ...             dfn.lit("x"), dfn.lit(1),
        ...             dfn.lit("y"), dfn.lit(2),
        ...         )
        ...     ).alias("keys"))
        >>> result.collect_column("keys")[0].as_py()
        ['x', 'y']
    """
    keys_expr = f.map_keys(map.expr)
    return Expr(keys_expr)
3429+
3430+
3431+
def map_values(map: Expr) -> Expr:
    """Collects every value of a map expression into a list.

    Args:
        map: The map expression to read values from.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.from_pydict({"a": [1]})
        >>> result = df.select(
        ...     dfn.functions.map_values(
        ...         dfn.functions.make_map(
        ...             dfn.lit("x"), dfn.lit(1),
        ...             dfn.lit("y"), dfn.lit(2),
        ...         )
        ...     ).alias("vals"))
        >>> result.collect_column("vals")[0].as_py()
        [1, 2]
    """
    values_expr = f.map_values(map.expr)
    return Expr(values_expr)
3448+
3449+
3450+
def map_extract(map: Expr, key: Expr) -> Expr:
    """Returns a list containing the value for the given key in the map.

    The result is always a list: it holds the matching value when ``key`` is
    present in ``map`` and is empty when the key is absent.

    Args:
        map: The map expression to look the key up in.
        key: The key expression to search for.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.from_pydict({"a": [1]})
        >>> result = df.select(
        ...     dfn.functions.map_extract(
        ...         dfn.functions.make_map(
        ...             dfn.lit("x"), dfn.lit(1),
        ...             dfn.lit("y"), dfn.lit(2),
        ...         ),
        ...         dfn.lit("x"),
        ...     ).alias("val"))
        >>> result.collect_column("val")[0].as_py()
        [1]
    """
    return Expr(f.map_extract(map.expr, key.expr))
3468+
3469+
3470+
def map_entries(map: Expr) -> Expr:
    """Lists the entries of a map expression as key-value structs.

    Each entry is a struct with a ``key`` field and a ``value`` field.

    Args:
        map: The map expression to read entries from.

    Examples:
        >>> ctx = dfn.SessionContext()
        >>> df = ctx.from_pydict({"a": [1]})
        >>> result = df.select(
        ...     dfn.functions.map_entries(
        ...         dfn.functions.make_map(
        ...             dfn.lit("x"), dfn.lit(1),
        ...             dfn.lit("y"), dfn.lit(2),
        ...         )
        ...     ).alias("entries"))
        >>> result.collect_column("entries")[0].as_py()
        [{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]
    """
    entries_expr = f.map_entries(map.expr)
    return Expr(entries_expr)
3487+
3488+
3489+
def element_at(map: Expr, key: Expr) -> Expr:
    """Returns a list containing the value for the given key in the map.

    The result is always a list: it holds the matching value when ``key`` is
    present in ``map`` and is empty when the key is absent.

    Args:
        map: The map expression to look the key up in.
        key: The key expression to search for.

    See Also:
        This is an alias for :py:func:`map_extract`.
    """
    return map_extract(map, key)
3496+
3497+
33773498
# aggregate functions
33783499
def approx_distinct(
33793500
expression: Expr,

0 commit comments

Comments
 (0)