Skip to content

Commit 93807d4

Browse files
committed
add tree scope
1 parent c377a52 commit 93807d4

8 files changed

Lines changed: 88 additions & 150 deletions

File tree

datafusion/physical-expr-adapter/src/schema_rewriter.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use std::sync::Arc;
2626

2727
use arrow::array::RecordBatch;
2828
use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef};
29+
use datafusion_common::tree_node::ScopedTreeNode;
2930
use datafusion_common::{
3031
DataFusionError, Result, ScalarValue, exec_err,
3132
metadata::FieldMetadata,
@@ -69,7 +70,7 @@ where
6970
K: Borrow<str> + Eq + Hash,
7071
V: Borrow<ScalarValue>,
7172
{
72-
expr.transform_down(|expr| {
73+
expr.transform_down_in_scope(|expr| {
7374
if let Some(column) = expr.as_any().downcast_ref::<Column>()
7475
&& let Some(replacement_value) = replacements.get(column.name())
7576
{

datafusion/physical-expr-common/src/physical_expr.rs

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,37 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
165165
/// Get a list of child PhysicalExpr that provide the input for this expr.
166166
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>;
167167

168+
/// Get the list of child `PhysicalExpr`s that provide the input for this expr and are in the same scope as this expression.
169+
///
170+
/// Because most expressions have all of their children in the same scope, the default implementation simply calls [`Self::children`].
171+
///
172+
/// To decide whether a specific child is in the same scope, ask this simple question:
173+
/// if that child were a `Column`, could that column be evaluated with the same input schema?
174+
/// Expressions like `plus`, `sum`, etc. have all of their children in scope.
175+
/// Lambda expressions like `array_filter(list, value -> value + 1)` have the `list` in the same scope and the lambda function in a different scope.
176+
fn children_in_scope(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
177+
self.children()
178+
}
179+
168180
/// Returns a new PhysicalExpr where all children were replaced by new exprs.
169181
fn with_new_children(
170182
self: Arc<Self>,
171183
children: Vec<Arc<dyn PhysicalExpr>>,
172184
) -> Result<Arc<dyn PhysicalExpr>>;
173185

186+
/// Returns a new PhysicalExpr where all scoped children were replaced by new exprs.
187+
///
188+
/// See [`Self::children_in_scope`] for the definition of which children are considered in scope.
189+
///
190+
/// Because most expressions have all of their children in the same scope, the default implementation simply calls [`Self::with_new_children`].
191+
///
192+
fn with_new_children_in_scope(
193+
self: Arc<Self>,
194+
children_in_scope: Vec<Arc<dyn PhysicalExpr>>,
195+
) -> Result<Arc<dyn PhysicalExpr>> {
196+
self.with_new_children(children_in_scope)
197+
}
198+
174199
/// Computes the output interval for the expression, given the input
175200
/// intervals.
176201
///
@@ -476,16 +501,40 @@ pub fn with_new_children_if_necessary(
476501
);
477502

478503
if children.is_empty()
479-
|| children
480-
.iter()
481-
.zip(old_children.iter())
482-
.any(|(c1, c2)| !Arc::ptr_eq(c1, c2))
504+
|| children
505+
.iter()
506+
.zip(old_children.iter())
507+
.any(|(c1, c2)| !Arc::ptr_eq(c1, c2))
483508
{
484509
Ok(expr.with_new_children(children)?)
485510
} else {
486511
Ok(expr)
487512
}
488513
}
514+
/// Rebuilds this expr with the given in-scope children if any of them differs from the current one by pointer comparison (or if the list is empty); otherwise returns `expr` unchanged.
515+
/// The size of `children_in_scope` must be equal to the size of [`PhysicalExpr::children_in_scope()`].
516+
pub fn with_new_children_in_scope_if_necessary(
517+
expr: Arc<dyn PhysicalExpr>,
518+
children_in_scope: Vec<Arc<dyn PhysicalExpr>>,
519+
) -> Result<Arc<dyn PhysicalExpr>> {
520+
let old_children_in_scope = expr.children_in_scope();
521+
assert_eq_or_internal_err!(
522+
children_in_scope.len(),
523+
old_children_in_scope.len(),
524+
"PhysicalExpr: Wrong number of children in scope"
525+
);
526+
527+
if children_in_scope.is_empty()
528+
|| children_in_scope
529+
.iter()
530+
.zip(old_children_in_scope.iter())
531+
.any(|(c1, c2)| !Arc::ptr_eq(c1, c2))
532+
{
533+
Ok(expr.with_new_children_in_scope(children_in_scope)?)
534+
} else {
535+
Ok(expr)
536+
}
537+
}
489538

490539
/// Returns a [`Display`]able list of [`PhysicalExpr`]
491540
///

datafusion/physical-expr-common/src/tree_node.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
use std::fmt::{self, Display, Formatter};
2121
use std::sync::Arc;
2222

23-
use crate::physical_expr::{PhysicalExpr, with_new_children_if_necessary};
23+
use crate::physical_expr::{PhysicalExpr, with_new_children_if_necessary, with_new_children_in_scope_if_necessary};
2424

2525
use datafusion_common::Result;
26-
use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode};
26+
use datafusion_common::tree_node::{ConcreteTreeNode, DynScopedTreeNode, DynTreeNode};
2727

2828
impl DynTreeNode for dyn PhysicalExpr {
2929
fn arc_children(&self) -> Vec<&Arc<Self>> {
@@ -39,6 +39,20 @@ impl DynTreeNode for dyn PhysicalExpr {
3939
}
4040
}
4141

42+
impl DynScopedTreeNode for dyn PhysicalExpr {
43+
fn arc_children_in_scope(&self) -> Vec<&Arc<Self>> {
44+
self.children_in_scope()
45+
}
46+
47+
fn with_new_arc_children_in_scope(
48+
&self,
49+
arc_self: Arc<Self>,
50+
new_children: Vec<Arc<Self>>,
51+
) -> Result<Arc<Self>> {
52+
with_new_children_in_scope_if_necessary(arc_self, new_children)
53+
}
54+
}
55+
4256
/// A node object encapsulating a [`PhysicalExpr`] node with a payload. Since there are
4357
/// two ways to access child plans—directly from the plan and through child nodes—it's
4458
/// recommended to perform mutable operations via [`Self::update_expr_from_children`].

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 7 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ use std::{any::Any, sync::Arc};
4040

4141
use crate::expressions::case::literal_lookup_table::LiteralLookupTable;
4242
use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n};
43-
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
43+
use datafusion_common::tree_node::{
44+
ScopedTreeNode, Transformed, TreeNode, TreeNodeRecursion,
45+
};
4446
use datafusion_physical_expr_common::datum::compare_with_eq;
4547
use datafusion_physical_expr_common::utils::scatter;
4648
use itertools::Itertools;
@@ -130,7 +132,7 @@ impl CaseBody {
130132
// Determine the set of columns that are used in all the expressions of the case body.
131133
let mut used_column_indices = IndexSet::<usize>::new();
132134
let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
133-
expr.apply(|expr| {
135+
expr.apply_in_scope(|expr| {
134136
if let Some(column) = expr.as_any().downcast_ref::<Column>() {
135137
used_column_indices.insert(column.index());
136138
}
@@ -161,7 +163,7 @@ impl CaseBody {
161163
// using the column index mapping.
162164
let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
163165
Arc::clone(expr)
164-
.transform_down(|e| {
166+
.transform_down_in_scope(|e| {
165167
if let Some(column) = e.as_any().downcast_ref::<Column>() {
166168
let original = column.index();
167169
let projected = *column_index_map.get(&original).unwrap();
@@ -1397,7 +1399,7 @@ fn replace_with_null(
13971399
input_schema: &Schema,
13981400
) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
13991401
let with_null = Arc::clone(expr)
1400-
.transform_down(|e| {
1402+
.transform_down_in_scope(|e| {
14011403
if e.as_ref().dyn_eq(expr_to_replace) {
14021404
let data_type = e.data_type(input_schema)?;
14031405
let null_literal = lit(ScalarValue::try_new_null(&data_type)?);
@@ -1928,135 +1930,6 @@ mod tests {
19281930
Ok(())
19291931
}
19301932

1931-
#[test]
1932-
fn case_without_expr_and_with_custom_column_impl() -> Result<()> {
1933-
/// Represents the column at a given index in a RecordBatch that is inside a Spark lambda function
1934-
///
1935-
/// This is the same as the datafusion [`datafusion::physical_expr::expressions::Column`] except that it store the entire info so that it can be used in lambda execution
1936-
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
1937-
pub struct CustomColumn {
1938-
/// The name of the column (used for debugging and display purposes)
1939-
name: String,
1940-
/// The index of the column in its schema
1941-
index: usize,
1942-
data_type: DataType,
1943-
nullable: bool,
1944-
}
1945-
1946-
impl CustomColumn {
1947-
pub fn new_with_schema(
1948-
name: &str,
1949-
schema: &Schema,
1950-
) -> Result<Arc<dyn PhysicalExpr>> {
1951-
let index = schema.index_of(name)?;
1952-
let field = schema.field(index);
1953-
Ok(Arc::new(CustomColumn {
1954-
name: name.to_string(),
1955-
index,
1956-
data_type: field.data_type().clone(),
1957-
nullable: field.is_nullable(),
1958-
}))
1959-
}
1960-
}
1961-
1962-
impl std::fmt::Display for CustomColumn {
1963-
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
1964-
write!(f, "{}@{}", self.name, self.index)
1965-
}
1966-
}
1967-
1968-
impl PhysicalExpr for CustomColumn {
1969-
fn as_any(&self) -> &dyn Any {
1970-
self
1971-
}
1972-
1973-
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
1974-
Ok(self.data_type.clone())
1975-
}
1976-
1977-
fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
1978-
Ok(self.nullable)
1979-
}
1980-
1981-
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1982-
self.bounds_check(batch.schema().as_ref())?;
1983-
Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index))))
1984-
}
1985-
1986-
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
1987-
vec![]
1988-
}
1989-
1990-
fn with_new_children(
1991-
self: Arc<Self>,
1992-
_children: Vec<Arc<dyn PhysicalExpr>>,
1993-
) -> Result<Arc<dyn PhysicalExpr>> {
1994-
Ok(self)
1995-
}
1996-
1997-
fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
1998-
unimplemented!()
1999-
}
2000-
}
2001-
2002-
impl CustomColumn {
2003-
fn bounds_check(&self, input_schema: &Schema) -> Result<()> {
2004-
if self.index < input_schema.fields.len() {
2005-
Ok(())
2006-
} else {
2007-
internal_err!(
2008-
"PhysicalExpr BoundLambdaColumn references column '{}' at index {} (zero-based) but input schema only has {} columns: {:?}",
2009-
self.name,
2010-
self.index,
2011-
input_schema.fields.len(),
2012-
input_schema
2013-
.fields()
2014-
.iter()
2015-
.map(|f| f.name())
2016-
.collect::<Vec<_>>()
2017-
)
2018-
}
2019-
}
2020-
}
2021-
2022-
let batch = case_test_batch()?;
2023-
let schema = batch.schema();
2024-
2025-
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
2026-
let when1 = binary(
2027-
CustomColumn::new_with_schema("a", &schema)?,
2028-
Operator::Eq,
2029-
lit("foo"),
2030-
&batch.schema(),
2031-
)?;
2032-
let then1 = lit(123i32);
2033-
let when2 = binary(
2034-
CustomColumn::new_with_schema("a", &schema)?,
2035-
Operator::Eq,
2036-
lit("bar"),
2037-
&batch.schema(),
2038-
)?;
2039-
let then2 = lit(456i32);
2040-
2041-
let expr = generate_case_when_with_type_coercion(
2042-
None,
2043-
vec![(when1, then1), (when2, then2)],
2044-
None,
2045-
schema.as_ref(),
2046-
)?;
2047-
let result = expr
2048-
.evaluate(&batch)?
2049-
.into_array(batch.num_rows())
2050-
.expect("Failed to convert to array");
2051-
let result = as_int32_array(&result)?;
2052-
2053-
let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
2054-
2055-
assert_eq!(expected, result);
2056-
2057-
Ok(())
2058-
}
2059-
20601933
#[test]
20611934
fn case_with_expr_when_null() -> Result<()> {
20621935
let batch = case_test_batch()?;
@@ -2552,7 +2425,7 @@ mod tests {
25522425
.unwrap();
25532426

25542427
let expr3 = Arc::clone(&expr)
2555-
.transform_down(|e| {
2428+
.transform_down_in_scope(|e| {
25562429
let transformed = match e.as_any().downcast_ref::<Literal>() {
25572430
Some(lit_value) => match lit_value.value() {
25582431
ScalarValue::Utf8(Some(str_value)) => {

datafusion/physical-expr/src/physical_expr.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::{LexOrdering, PhysicalSortExpr, create_physical_expr};
2222

2323
use arrow::compute::SortOptions;
2424
use arrow::datatypes::{Schema, SchemaRef};
25-
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
25+
use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult, TreeNode};
2626
use datafusion_common::{DFSchema, HashMap};
2727
use datafusion_common::{Result, plan_err};
2828
use datafusion_expr::execution_props::ExecutionProps;
@@ -38,7 +38,7 @@ pub fn add_offset_to_expr(
3838
expr: Arc<dyn PhysicalExpr>,
3939
offset: isize,
4040
) -> Result<Arc<dyn PhysicalExpr>> {
41-
expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
41+
expr.transform_down_in_scope(|e| match e.as_any().downcast_ref::<Column>() {
4242
Some(col) => {
4343
let Some(idx) = col.index().checked_add_signed(offset) else {
4444
return plan_err!("Column index overflow");

datafusion/physical-expr/src/projection.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use crate::utils::collect_columns;
2727
use arrow::array::{RecordBatch, RecordBatchOptions};
2828
use arrow::datatypes::{Field, Schema, SchemaRef};
2929
use datafusion_common::stats::{ColumnStatistics, Precision};
30-
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
30+
use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult, TreeNode};
3131
use datafusion_common::{
3232
Result, ScalarValue, Statistics, assert_or_internal_err, internal_datafusion_err,
3333
plan_err,
@@ -920,7 +920,7 @@ pub fn update_expr(
920920
let mut state = RewriteState::Unchanged;
921921

922922
let new_expr = Arc::clone(expr)
923-
.transform_up(|expr| {
923+
.transform_up_in_scope(|expr| {
924924
if state == RewriteState::RewrittenInvalid {
925925
return Ok(Transformed::no(expr));
926926
}
@@ -1043,7 +1043,7 @@ impl ProjectionMapping {
10431043
let mut map = IndexMap::<_, ProjectionTargets>::new();
10441044
for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
10451045
let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
1046-
let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
1046+
let source_expr = expr.transform_down_in_scope(|e| match e.as_any().downcast_ref::<Column>() {
10471047
Some(col) => {
10481048
// Sometimes, an expression and its name in the input_schema
10491049
// doesn't match. This can cause problems, so we make sure
@@ -1162,7 +1162,7 @@ pub fn project_ordering(
11621162
) -> Option<LexOrdering> {
11631163
let mut projected_exprs = vec![];
11641164
for PhysicalSortExpr { expr, options } in ordering.iter() {
1165-
let transformed = Arc::clone(expr).transform_up(|expr| {
1165+
let transformed = Arc::clone(expr).transform_up_in_scope(|expr| {
11661166
let Some(col) = expr.as_any().downcast_ref::<Column>() else {
11671167
return Ok(Transformed::no(expr));
11681168
};

datafusion/physical-expr/src/utils/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
mod guarantee;
19+
use datafusion_common::tree_node::ScopedTreeNode;
1920
pub use guarantee::{Guarantee, LiteralGuarantee};
2021

2122
use std::borrow::Borrow;
@@ -251,7 +252,7 @@ pub fn reassign_expr_columns(
251252
expr: Arc<dyn PhysicalExpr>,
252253
schema: &Schema,
253254
) -> Result<Arc<dyn PhysicalExpr>> {
254-
expr.transform_down(|expr| {
255+
expr.transform_down_in_scope(|expr| {
255256
if let Some(column) = expr.as_any().downcast_ref::<Column>() {
256257
let index = schema.index_of(column.name())?;
257258

0 commit comments

Comments
 (0)