diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py
index 71ff8c59..928b5c06 100644
--- a/duckdb/experimental/spark/sql/functions.py
+++ b/duckdb/experimental/spark/sql/functions.py
@@ -1893,6 +1893,77 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:  # noqa: ANN401
     )
 
 
+def explode(col: "ColumnOrName") -> Column:
+    """Returns a new row for each element in the given array or map.
+
+    Uses the default column name ``col`` for elements in the array
+    and ``key`` and ``value`` for elements in the map unless specified otherwise.
+
+    .. versionadded:: 1.3.0
+
+    .. versionchanged:: 3.4.0
+        Supports Spark Connect.
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        target column to compute on.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        one row per array element or map entry.
+
+    Notes
+    -----
+    Rows with ``NULL`` or empty arrays/maps are dropped from the output.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(1, [1, 2, 3]), (2, [4, 5])], ["id", "data"])
+    >>> df.select("id", explode("data").alias("val")).collect()
+    [Row(id=1, val=1), Row(id=1, val=2), Row(id=1, val=3), Row(id=2, val=4), Row(id=2, val=5)]
+    """
+    return Column(FunctionExpression("unnest", _to_column_expr(col)))
+
+
+def explode_outer(col: "ColumnOrName") -> Column:
+    """Returns a new row for each element in the given array or map.
+
+    Unlike explode, if the array/map is ``NULL`` or empty, the row (``NULL``) is produced.
+    Uses the default column name ``col`` for elements in the array
+    and ``key`` and ``value`` for elements in the map unless specified otherwise.
+
+    .. versionadded:: 2.3.0
+
+    .. versionchanged:: 3.4.0
+        Supports Spark Connect.
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        target column to compute on.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        one row per array element or map entry, with ``NULL`` for empty/``NULL`` inputs.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(1, [1, 2]), (2, None), (3, [])], "id: int, data: array<int>")
+    >>> df.select("id", explode_outer("data").alias("val")).collect()
+    [Row(id=1, val=1), Row(id=1, val=2), Row(id=2, val=None), Row(id=3, val=None)]
+    """
+    col_expr = _to_column_expr(col)
+    is_null = col_expr.isnull()
+    is_empty = FunctionExpression("array_length", col_expr).__eq__(ConstantExpression(0))
+    null_or_empty = is_null.__or__(is_empty)
+    null_list = FunctionExpression("list_value", ConstantExpression(None))
+    case_expr = CaseExpression(null_or_empty, null_list).otherwise(col_expr)
+    return Column(FunctionExpression("unnest", case_expr))
+
+
 def last_day(date: "ColumnOrName") -> Column:
     """Returns the last day of the month which the given date belongs to.
 
diff --git a/tests/fast/spark/test_spark_functions_array.py b/tests/fast/spark/test_spark_functions_array.py
index 9ee2ffc2..047c7fe9 100644
--- a/tests/fast/spark/test_spark_functions_array.py
+++ b/tests/fast/spark/test_spark_functions_array.py
@@ -235,3 +235,48 @@ def test_arrays_zip(self, spark):
             ]
         else:
             assert res == [Row(zipped=[(1, 2, 3), (2, 4, 6), (3, 6, None)])]
+
+    def test_explode(self, spark):
+        df = spark.createDataFrame([(1, [10, 20, 30]), (2, [40, 50])], ["id", "data"])
+
+        res = df.select("id", sf.explode("data").alias("val")).collect()
+        assert res == [
+            Row(id=1, val=10),
+            Row(id=1, val=20),
+            Row(id=1, val=30),
+            Row(id=2, val=40),
+            Row(id=2, val=50),
+        ]
+
+    def test_explode_drops_null_and_empty(self, spark):
+        df = spark.createDataFrame([(1, [1, 2]), (2, None), (3, [])], ["id", "data"])
+
+        res = df.select("id", sf.explode("data").alias("val")).collect()
+        assert res == [Row(id=1, val=1), Row(id=1, val=2)]
+
+    def test_explode_with_column_object(self, spark):
+        df = spark.createDataFrame([([1, 2, 3],)], ["data"])
+
+        res = df.select(sf.explode(df.data).alias("val")).collect()
+        assert res == [Row(val=1), Row(val=2), Row(val=3)]
+
+    def test_explode_outer(self, spark):
+        df = spark.createDataFrame([(1, [1, 2]), (2, None), (3, [])], ["id", "data"])
+
+        res = df.select("id", sf.explode_outer("data").alias("val")).collect()
+        assert res == [
+            Row(id=1, val=1),
+            Row(id=1, val=2),
+            Row(id=2, val=None),
+            Row(id=3, val=None),
+        ]
+
+    def test_explode_outer_all_populated(self, spark):
+        df = spark.createDataFrame([(1, [10, 20]), (2, [30])], ["id", "data"])
+
+        res = df.select("id", sf.explode_outer("data").alias("val")).collect()
+        assert res == [
+            Row(id=1, val=10),
+            Row(id=1, val=20),
+            Row(id=2, val=30),
+        ]