diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml index 12a5599189216..63cf18027c80f 100644 --- a/.github/workflows/large_files.yml +++ b/.github/workflows/large_files.yml @@ -34,9 +34,9 @@ jobs: fetch-depth: 0 - name: Check size of new Git objects env: - # 1 MB ought to be enough for anybody. + # 2 MB ought to be enough for anybody. # TODO in case we may want to consciously commit a bigger file to the repo without using Git LFS we may disable the check e.g. with a label - MAX_FILE_SIZE_BYTES: 1048576 + MAX_FILE_SIZE_BYTES: 2097152 shell: bash run: | if [ "${{ github.event_name }}" = "merge_group" ]; then diff --git a/Cargo.lock b/Cargo.lock index 5a76c063bbfad..513b0420ae062 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1802,6 +1802,7 @@ dependencies = [ "flate2", "futures", "glob", + "indexmap 2.14.0", "insta", "itertools 0.14.0", "liblzma", diff --git a/datafusion-examples/examples/custom_data_source/adapter_serialization.rs b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs index 22f8589d6f986..d82bd2097ce1d 100644 --- a/datafusion-examples/examples/custom_data_source/adapter_serialization.rs +++ b/datafusion-examples/examples/custom_data_source/adapter_serialization.rs @@ -17,9 +17,9 @@ //! See `main.rs` for how to run it. //! -//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's -//! interception methods (`serialize_physical_plan` and `deserialize_physical_plan`) -//! to implement custom serialization logic. +//! This example demonstrates how to use the `PhysicalProtoConverterExtension` +//! trait's interception methods (`execution_plan_to_proto` and +//! `proto_to_execution_plan`) to implement custom serialization logic. //! //! The key insight is that `FileScanConfig::expr_adapter_factory` is NOT serialized by //! default. This example shows how to: @@ -28,9 +28,10 @@ //! 3. Store the inner DataSourceExec (without adapter) as a child in the extension's inputs field //! 4. Unwrap and restore the adapter during deserialization //! -//! This demonstrates nested serialization (protobuf outer, JSON inner) and the power -//! of the `PhysicalExtensionCodec` interception pattern. Both plan and expression -//! serialization route through the codec, enabling interception at every node in the tree. +//! This demonstrates nested serialization (protobuf outer, JSON inner) and the +//! power of `PhysicalProtoConverterExtension`. Both plan and expression +//! serialization route through converter hooks, enabling interception at every +//! node in the tree. use std::fmt::Debug; use std::sync::Arc; @@ -61,7 +62,7 @@ use datafusion_proto::bytes::{ use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; use datafusion_proto::physical_plan::{ - PhysicalExtensionCodec, PhysicalProtoConverterExtension, + PhysicalExtensionCodec, PhysicalPlanDecodeContext, PhysicalProtoConverterExtension, }; use datafusion_proto::protobuf::physical_plan_node::PhysicalPlanType; use datafusion_proto::protobuf::{ @@ -177,7 +178,7 @@ pub async fn adapter_serialization() -> Result<()> { println!("\n=== Example Complete! ==="); println!("Key takeaways:"); println!( - " 1. PhysicalExtensionCodec provides serialize_physical_plan/deserialize_physical_plan hooks" + " 1. PhysicalProtoConverterExtension provides execution_plan_to_proto/proto_to_execution_plan hooks" ); println!(" 2. 
Custom metadata can be wrapped as PhysicalExtensionNode"); println!(" 3. Nested serialization (protobuf + JSON) works seamlessly"); @@ -303,9 +304,10 @@ impl PhysicalExtensionCodec for AdapterPreservingCodec { _node: Arc, _buf: &mut Vec, ) -> Result<()> { - // We don't need this for the example - we use serialize_physical_plan instead + // We don't need this for the example - adapter wrapping happens in + // `execution_plan_to_proto` instead. not_impl_err!( - "try_encode not used - adapter wrapping happens in serialize_physical_plan" + "try_encode not used - adapter wrapping happens in execution_plan_to_proto" ) } } @@ -371,9 +373,8 @@ impl PhysicalProtoConverterExtension for AdapterPreservingCodec { // Interception point: override deserialization to unwrap adapters fn proto_to_execution_plan( &self, - ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, proto: &PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result> { // Check if this is our custom extension wrapper if let Some(PhysicalPlanType::Extension(extension)) = &proto.physical_plan_type @@ -395,11 +396,7 @@ impl PhysicalProtoConverterExtension for AdapterPreservingCodec { let inner_proto = &extension.inputs[0]; // Deserialize the inner plan - let inner_plan = inner_proto.try_into_physical_plan_with_converter( - ctx, - extension_codec, - self, - )?; + let inner_plan = self.default_proto_to_execution_plan(inner_proto, ctx)?; // Recreate the adapter factory let adapter_factory = create_adapter_factory(&payload.adapter_metadata.tag); @@ -409,17 +406,16 @@ impl PhysicalProtoConverterExtension for AdapterPreservingCodec { } // Not our extension - use default deserialization - proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + self.default_proto_to_execution_plan(proto, ctx) } fn proto_to_physical_expr( &self, proto: &PhysicalExprNode, - ctx: &TaskContext, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result> { - parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + parse_physical_expr_with_converter(proto, input_schema, ctx, self) } fn physical_expr_to_proto( diff --git a/datafusion-examples/examples/proto/expression_deduplication.rs b/datafusion-examples/examples/proto/expression_deduplication.rs index ec076b52498cc..26d246b2efca8 100644 --- a/datafusion-examples/examples/proto/expression_deduplication.rs +++ b/datafusion-examples/examples/proto/expression_deduplication.rs @@ -17,8 +17,9 @@ //! See `main.rs` for how to run it. //! -//! This example demonstrates how to use the `PhysicalExtensionCodec` trait's -//! interception methods to implement expression deduplication during deserialization. +//! This example demonstrates how to use the +//! `PhysicalProtoConverterExtension` trait's interception methods to +//! implement expression deduplication during deserialization. //! //! This pattern is inspired by PR #18192, which introduces expression caching //! to reduce memory usage when deserializing plans with duplicate expressions. @@ -29,8 +30,9 @@ //! 2. Reduce memory allocation during deserialization //! 3. Enable downstream optimizations that rely on Arc pointer equality //! -//! This demonstrates the decorator pattern enabled by the `PhysicalExtensionCodec` trait, -//! where all expression serialization/deserialization routes through the codec methods. +//! This demonstrates the decorator pattern enabled by +//! `PhysicalProtoConverterExtension`, where physical-expression +//! 
serialization and deserialization route through converter hooks. use std::collections::HashMap; use std::fmt::Debug; @@ -49,7 +51,7 @@ use datafusion::prelude::SessionContext; use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; use datafusion_proto::physical_plan::{ - DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, PhysicalPlanDecodeContext, PhysicalProtoConverterExtension, }; use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; @@ -202,11 +204,10 @@ impl PhysicalExtensionCodec for CachingCodec { impl PhysicalProtoConverterExtension for CachingCodec { fn proto_to_execution_plan( &self, - ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, proto: &PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result> { - proto.try_into_physical_plan_with_converter(ctx, extension_codec, self) + self.default_proto_to_execution_plan(proto, ctx) } fn execution_plan_to_proto( @@ -225,9 +226,8 @@ impl PhysicalProtoConverterExtension for CachingCodec { fn proto_to_physical_expr( &self, proto: &PhysicalExprNode, - ctx: &TaskContext, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result> { // Create cache key from protobuf bytes let mut key = Vec::new(); @@ -249,8 +249,7 @@ impl PhysicalProtoConverterExtension for CachingCodec { } // Cache miss - deserialize and store - let expr = - parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self)?; + let expr = parse_physical_expr_with_converter(proto, input_schema, ctx, self)?; // Store in cache { diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index a2a07d4598b0a..0a3783de9de79 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -144,6 +144,7 @@ datafusion-session = { workspace = true } datafusion-sql = { workspace = true, optional = true } flate2 = { workspace = true, optional = true } futures = { workspace = true } +indexmap = { workspace = true } itertools = { workspace = true } liblzma = { workspace = true, optional = true } log = { workspace = true } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index a025446aa37e8..3b2c7a78e898e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -63,6 +63,7 @@ use arrow::datatypes::Schema; use arrow_schema::Field; use datafusion_catalog::ScanArgs; use datafusion_common::Column; +use datafusion_common::HashMap as DFHashMap; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::format::ExplainAnalyzeCategories; use datafusion_common::tree_node::{ @@ -78,11 +79,13 @@ use datafusion_common::{ use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_expr::dml::{CopyTo, InsertOp}; +use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex}; use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, NullTreatment, WindowFunction, WindowFunctionParams, physical_name, }; use datafusion_expr::expr_rewriter::unnormalize_cols; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::utils::{expr_to_columns, split_conjunction}; use datafusion_expr::{ @@ -101,11 +104,13 @@ use 
datafusion_physical_plan::execution_plan::InvariantLevel; use datafusion_physical_plan::joins::PiecewiseMergeJoinExec; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::recursive_query::RecursiveQueryExec; +use datafusion_physical_plan::scalar_subquery::{ScalarSubqueryExec, ScalarSubqueryLink}; use datafusion_physical_plan::unnest::ListUnnest; use async_trait::async_trait; use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper}; use futures::{StreamExt, TryStreamExt}; +use indexmap::IndexSet; use itertools::{Itertools, multiunzip}; use log::debug; use tokio::sync::Mutex; @@ -379,8 +384,93 @@ impl DefaultPhysicalPlanner { Ok(()) } - /// Create a physical plan from a logical plan - async fn create_initial_plan( + /// Collect uncorrelated scalar subqueries. We don't descend into nested + /// subqueries here: each call to `create_initial_plan` handles subqueries + /// at its level and then recurses in order to handle nested subqueries. + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Subquery contains Arc with interior mutability but is intentionally used as hash key + fn collect_scalar_subqueries(plan: &LogicalPlan) -> Vec { + let mut subqueries = IndexSet::new(); + plan.apply(|node| { + for expr in node.expressions() { + expr.apply(|e| { + if let Expr::ScalarSubquery(sq) = e + && sq.outer_ref_columns.is_empty() + { + subqueries.insert(sq.clone()); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("infallible"); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("infallible"); + subqueries.into_iter().collect() + } + + /// Create a physical plan from a logical plan. + /// + /// Uncorrelated scalar subqueries in the plan's own expressions are + /// collected, planned as separate physical plans, and each assigned an + /// index in a shared [`ScalarSubqueryResults`] container that will hold its + /// result at execution time. The index map and shared results container are + /// registered in [`ExecutionProps`] so that [`create_physical_expr`] can + /// convert `Expr::ScalarSubquery` into [`ScalarSubqueryExpr`] nodes that + /// read from that container. + /// + /// The resulting physical plan is wrapped in a [`ScalarSubqueryExec`] node + /// that executes those subquery plans before any data flows through the + /// main plan. If a subquery itself contains nested uncorrelated subqueries, + /// the recursive call produces its own [`ScalarSubqueryExec`] inside the + /// subquery plan — each level manages only its own subqueries. + /// + /// Returns a [`BoxFuture`] rather than using `async fn` because of + /// this recursion. + /// + /// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr + /// [`BoxFuture`]: futures::future::BoxFuture + fn create_initial_plan<'a>( + &'a self, + logical_plan: &'a LogicalPlan, + session_state: &'a SessionState, + ) -> futures::future::BoxFuture<'a, Result>> { + Box::pin(async move { + let all_subqueries = Self::collect_scalar_subqueries(logical_plan); + let (links, index_map) = self + .plan_scalar_subqueries(all_subqueries, session_state) + .await?; + + if links.is_empty() { + return self + .create_initial_plan_inner(logical_plan, session_state) + .await; + } + + // Create the shared `ScalarSubqueryResults` container and register + // it in `ExecutionProps` so that `create_physical_expr` can resolve + // `Expr::ScalarSubquery` into `ScalarSubqueryExpr` nodes. 
We clone + // the `SessionState` so these are available throughout physical + // planning without mutating the caller's state. + // + // Ideally, the subquery state would live in a dedicated planning + // context rather than in `ExecutionProps`. It's here because + // `create_physical_expr` only receives `&ExecutionProps`. + let results = ScalarSubqueryResults::new(links.len()); + let mut owned = session_state.clone(); + owned.execution_props_mut().subquery_indexes = index_map; + owned.execution_props_mut().subquery_results = results.clone(); + let session_state = Cow::Owned(owned); + + let plan = self + .create_initial_plan_inner(logical_plan, &session_state) + .await?; + Ok(Arc::new(ScalarSubqueryExec::new(plan, links, results))) + }) + } + + /// Inner physical planning that converts a logical plan tree into an + /// execution plan tree without collecting scalar subqueries. + async fn create_initial_plan_inner( &self, logical_plan: &LogicalPlan, session_state: &SessionState, @@ -545,6 +635,7 @@ impl DefaultPhysicalPlanner { session_state: &SessionState, children: ChildrenContainer, ) -> Result> { + let execution_props = session_state.execution_props(); let exec_node: Arc = match node { // Leaves (no children) LogicalPlan::TableScan(scan) => { @@ -601,7 +692,7 @@ impl DefaultPhysicalPlanner { .map(|row| { row.iter() .map(|expr| { - self.create_physical_expr(expr, schema, session_state) + create_physical_expr(expr, schema, execution_props) }) .collect::>>>() }) @@ -860,13 +951,7 @@ impl DefaultPhysicalPlanner { let logical_schema = node.schema(); let window_expr = window_expr .iter() - .map(|e| { - create_window_expr( - e, - logical_schema, - session_state.execution_props(), - ) - }) + .map(|e| create_window_expr(e, logical_schema, execution_props)) .collect::>>()?; let can_repartition = session_state.config().target_partitions() > 1 @@ -971,7 +1056,7 @@ impl DefaultPhysicalPlanner { group_expr, logical_input_schema, &physical_input_schema, - session_state, + execution_props, )?; let agg_filter = aggr_expr @@ -981,7 +1066,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - session_state.execution_props(), + execution_props, ) }) .collect::>>()?; @@ -1075,8 +1160,8 @@ impl DefaultPhysicalPlanner { )?) } LogicalPlan::Projection(Projection { input, expr, .. }) => self - .create_project_physical_exec( - session_state, + .create_project_physical_exec_with_props( + execution_props, children.one()?, input, expr, @@ -1086,9 +1171,8 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.schema(); - let runtime_expr = - self.create_physical_expr(predicate, input_dfschema, session_state)?; + create_physical_expr(predicate, input_dfschema, execution_props)?; let input_schema = input.schema(); let filter = match self.try_plan_async_exprs( @@ -1136,7 +1220,9 @@ impl DefaultPhysicalPlanner { .options() .optimizer .default_filter_selectivity; - Arc::new(filter.with_default_selectivity(selectivity)?) 
+ let filter_exec: Arc = + Arc::new(filter.with_default_selectivity(selectivity)?); + filter_exec } LogicalPlan::Repartition(Repartition { input, @@ -1152,11 +1238,7 @@ impl DefaultPhysicalPlanner { let runtime_expr = expr .iter() .map(|e| { - self.create_physical_expr( - e, - input_dfschema, - session_state, - ) + create_physical_expr(e, input_dfschema, execution_props) }) .collect::>>()?; Partitioning::Hash(runtime_expr, *n) @@ -1177,11 +1259,8 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); - let sort_exprs = create_physical_sort_exprs( - expr, - input_dfschema, - session_state.execution_props(), - )?; + let sort_exprs = + create_physical_sort_exprs(expr, input_dfschema, execution_props)?; let Some(ordering) = LexOrdering::new(sort_exprs) else { return internal_err!( "SortExec requires at least one sort expression" @@ -1308,8 +1387,8 @@ impl DefaultPhysicalPlanner { ( true, LogicalPlan::Projection(Projection { input, expr, .. }), - ) => self.create_project_physical_exec( - session_state, + ) => self.create_project_physical_exec_with_props( + execution_props, physical_left, input, expr, @@ -1321,8 +1400,8 @@ impl DefaultPhysicalPlanner { ( true, LogicalPlan::Projection(Projection { input, expr, .. }), - ) => self.create_project_physical_exec( - session_state, + ) => self.create_project_physical_exec_with_props( + execution_props, physical_right, input, expr, @@ -1387,7 +1466,6 @@ impl DefaultPhysicalPlanner { // All equi-join keys are columns now, create physical join plan let left_df_schema = left.schema(); let right_df_schema = right.schema(); - let execution_props = session_state.execution_props(); let join_on = keys .iter() .map(|(l, r)| { @@ -1493,7 +1571,7 @@ impl DefaultPhysicalPlanner { let filter_expr = create_physical_expr( expr, &filter_df_schema, - session_state.execution_props(), + execution_props, )?; let column_indices = join_utils::JoinFilter::build_column_indices( left_field_indices, @@ -1613,12 +1691,12 @@ impl DefaultPhysicalPlanner { let on_left = create_physical_expr( lhs_logical, left_df_schema, - session_state.execution_props(), + execution_props, )?; let on_right = create_physical_expr( rhs_logical, right_df_schema, - session_state.execution_props(), + execution_props, )?; Arc::new(PiecewiseMergeJoinExec::try_new( @@ -1688,7 +1766,12 @@ impl DefaultPhysicalPlanner { // If plan was mutated previously then need to create the ExecutionPlan // for the new Projection that was applied on top. if let Some((input, expr)) = new_project { - self.create_project_physical_exec(session_state, join, input, expr)? + self.create_project_physical_exec_with_props( + execution_props, + join, + input, + expr, + )? 
} else { join } @@ -1782,7 +1865,7 @@ impl DefaultPhysicalPlanner { group_expr: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { if group_expr.len() == 1 { match &group_expr[0] { @@ -1791,25 +1874,25 @@ impl DefaultPhysicalPlanner { grouping_sets, input_dfschema, input_schema, - session_state, + execution_props, ) } Expr::GroupingSet(GroupingSet::Cube(exprs)) => create_cube_physical_expr( exprs, input_dfschema, input_schema, - session_state, + execution_props, ), Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { create_rollup_physical_expr( exprs, input_dfschema, input_schema, - session_state, + execution_props, ) } expr => Ok(PhysicalGroupBy::new_single(vec![tuple_err(( - self.create_physical_expr(expr, input_dfschema, session_state), + create_physical_expr(expr, input_dfschema, execution_props), physical_name(expr), ))?])), } @@ -1823,7 +1906,7 @@ impl DefaultPhysicalPlanner { .iter() .map(|e| { tuple_err(( - self.create_physical_expr(e, input_dfschema, session_state), + create_physical_expr(e, input_dfschema, execution_props), physical_name(e), )) }) @@ -1847,7 +1930,7 @@ fn merge_grouping_set_physical_expr( grouping_sets: &[Vec], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_groups = grouping_sets.len(); let mut all_exprs: Vec = vec![]; @@ -1861,14 +1944,14 @@ fn merge_grouping_set_physical_expr( grouping_set_expr.push(get_physical_expr_pair( expr, input_dfschema, - session_state, + execution_props, )?); null_exprs.push(get_null_physical_expr_pair( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); } } @@ -1898,7 +1981,7 @@ fn create_cube_physical_expr( exprs: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_of_exprs = exprs.len(); let num_groups = num_of_exprs * num_of_exprs; @@ -1913,10 +1996,14 @@ fn create_cube_physical_expr( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); - all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + execution_props, + )?) } let mut groups: Vec> = Vec::with_capacity(num_groups); @@ -1940,7 +2027,7 @@ fn create_rollup_physical_expr( exprs: &[Expr], input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result { let num_of_exprs = exprs.len(); @@ -1956,10 +2043,14 @@ fn create_rollup_physical_expr( expr, input_dfschema, input_schema, - session_state, + execution_props, )?); - all_exprs.push(get_physical_expr_pair(expr, input_dfschema, session_state)?) + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + execution_props, + )?) 
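+        // Worked example of the expansion built here (illustrative comment,
+        // added for clarity): ROLLUP(a, b, c) produces the n + 1 = 4 grouping
+        // sets (a, b, c), (a, b), (a), () - the `for total in 0..=num_of_exprs`
+        // loop that follows emits one set per prefix length `total`.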
} for total in 0..=num_of_exprs { @@ -1984,10 +2075,9 @@ fn get_null_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, input_schema: &Schema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result<(Arc, String)> { - let physical_expr = - create_physical_expr(expr, input_dfschema, session_state.execution_props())?; + let physical_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let physical_name = physical_name(&expr.clone())?; let data_type = physical_expr.data_type(input_schema)?; @@ -2056,10 +2146,9 @@ fn qualify_join_schema_sides( fn get_physical_expr_pair( expr: &Expr, input_dfschema: &DFSchema, - session_state: &SessionState, + execution_props: &ExecutionProps, ) -> Result<(Arc, String)> { - let physical_expr = - create_physical_expr(expr, input_dfschema, session_state.execution_props())?; + let physical_expr = create_physical_expr(expr, input_dfschema, execution_props)?; let physical_name = physical_name(expr)?; Ok((physical_expr, physical_name)) } @@ -2839,9 +2928,37 @@ impl DefaultPhysicalPlanner { Ok(mem_exec) } - fn create_project_physical_exec( + /// Build physical plans for scalar subqueries and assign each an ordinal + /// `SubqueryIndex`. Returns the links (plan + index) and a map from logical + /// `Subquery` to its index. + async fn plan_scalar_subqueries( &self, + subqueries: Vec, session_state: &SessionState, + ) -> Result<(Vec, DFHashMap)> { + let mut links = Vec::with_capacity(subqueries.len()); + let mut index_map = DFHashMap::with_capacity(subqueries.len()); + for sq in subqueries { + // Callers deduplicate, but guard against accidental double-planning. + if index_map.contains_key(&sq) { + continue; + } + let physical_plan = self + .create_initial_plan(&sq.subquery, session_state) + .await?; + let index = SubqueryIndex::new(links.len()); + links.push(ScalarSubqueryLink { + plan: physical_plan, + index, + }); + index_map.insert(sq, index); + } + Ok((links, index_map)) + } + + fn create_project_physical_exec_with_props( + &self, + execution_props: &ExecutionProps, input_exec: Arc, input: &Arc, expr: &[Expr], @@ -2880,7 +2997,7 @@ impl DefaultPhysicalPlanner { }; let physical_expr = - self.create_physical_expr(e, input_logical_schema, session_state); + create_physical_expr(e, input_logical_schema, execution_props); tuple_err((physical_expr, physical_name)) }) @@ -3124,7 +3241,7 @@ mod tests { use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - DFSchemaRef, TableReference, ToDFSchema as _, assert_contains, + DFSchemaRef, TableReference, ToDFSchema as _, assert_batches_eq, assert_contains, }; use datafusion_execution::TaskContext; use datafusion_execution::runtime_env::RuntimeEnv; @@ -3158,6 +3275,16 @@ mod tests { .await } + async fn plan_sql(query: &str) -> Result> { + let ctx = SessionContext::new(); + ctx.sql(query).await?.create_physical_plan().await + } + + async fn collect_sql(query: &str) -> Result> { + let ctx = SessionContext::new(); + ctx.sql(query).await?.collect().await + } + #[tokio::test] async fn test_all_operators() -> Result<()> { let logical_plan = test_csv_scan() @@ -3181,6 +3308,132 @@ mod tests { Ok(()) } + #[tokio::test] + async fn scalar_subquery_in_sort_expr_plans() -> Result<()> { + let plan = plan_sql( + "SELECT x \ + FROM (VALUES (2), (1)) AS t(x) \ + ORDER BY x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))", + ) + .await?; + + assert_contains!(format!("{plan:?}"), "ScalarSubqueryExec"); + Ok(()) + } + + #[tokio::test] + async fn 
scalar_subquery_in_sort_expr_executes() -> Result<()> { + let batches = collect_sql( + "SELECT x \ + FROM (VALUES (2), (1), (3)) AS t(x) \ + ORDER BY x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) DESC", + ) + .await?; + + assert_batches_eq!( + &[ + "+---+", "| x |", "+---+", "| 3 |", "| 2 |", "| 1 |", "+---+", + ], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_aggregate_arg_plans() -> Result<()> { + let plan = plan_sql( + "SELECT sum(x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))) \ + FROM (VALUES (2), (1)) AS t(x)", + ) + .await?; + + assert_contains!(format!("{plan:?}"), "ScalarSubqueryExec"); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_aggregate_arg_executes() -> Result<()> { + let batches = collect_sql( + "SELECT sum(x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y))) AS s \ + FROM (VALUES (2), (1)) AS t(x)", + ) + .await?; + + assert_batches_eq!( + &["+----+", "| s |", "+----+", "| 43 |", "+----+",], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_join_on_plans() -> Result<()> { + let plan = plan_sql( + "SELECT l.x, r.y \ + FROM (VALUES (1), (2)) AS l(x) \ + JOIN (VALUES (11), (12)) AS r(y) \ + ON l.x + (SELECT 10) = r.y", + ) + .await?; + + let formatted = format!("{plan:?}"); + assert_contains!(&formatted, "ScalarSubqueryExec"); + assert!( + formatted.contains("HashJoinExec") + || formatted.contains("SortMergeJoinExec") + || formatted.contains("NestedLoopJoinExec") + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_mixed_correlated_and_uncorrelated_executes() -> Result<()> { + let query = "SELECT t.x, \ + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) + \ + (SELECT count(*) FROM (VALUES (1), (1), (2)) AS v(z) WHERE v.z = t.x) AS total \ + FROM (VALUES (1), (2), (3)) AS t(x) \ + ORDER BY x"; + let plan = plan_sql(query).await?; + + let formatted = format!("{plan:?}"); + assert_eq!(formatted.matches("ScalarSubqueryExec").count(), 1); + assert!( + formatted.contains("HashJoinExec") + || formatted.contains("SortMergeJoinExec") + || formatted.contains("NestedLoopJoinExec") + ); + + let batches = collect_sql(query).await?; + assert_batches_eq!( + &[ + "+---+-------+", + "| x | total |", + "+---+-------+", + "| 1 | 22 |", + "| 2 | 21 |", + "| 3 | 20 |", + "+---+-------+", + ], + &batches + ); + Ok(()) + } + + #[tokio::test] + async fn scalar_subquery_in_projection_and_filter_plans() -> Result<()> { + let plan = plan_sql( + "SELECT x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) \ + FROM (VALUES (2), (1)) AS t(x) \ + WHERE x > (SELECT min(y) FROM (VALUES (0), (1)) AS v(y))", + ) + .await?; + + let formatted = format!("{plan:?}"); + // All uncorrelated scalar subqueries are hoisted to a single root node. 
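+        // (the projection subquery and the filter subquery are both registered
+        // in that single node rather than each receiving its own wrapper)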
+ assert_eq!(formatted.matches("ScalarSubqueryExec").count(), 1); + Ok(()) + } + #[tokio::test] async fn test_create_cube_expr() -> Result<()> { let logical_plan = test_csv_scan().await?.build()?; @@ -3198,7 +3451,7 @@ mod tests { &exprs, logical_input_schema, physical_input_schema, - &session_state, + session_state.execution_props(), ); insta::assert_debug_snapshot!(cube, @r#" @@ -3329,7 +3582,7 @@ mod tests { &exprs, logical_input_schema, physical_input_schema, - &session_state, + session_state.execution_props(), ); insta::assert_debug_snapshot!(rollup, @r#" diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 3bf6978eb60ee..f16854f924e3a 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -18,9 +18,13 @@ use crate::var_provider::{VarProvider, VarType}; use chrono::{DateTime, Utc}; use datafusion_common::HashMap; +use datafusion_common::ScalarValue; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; -use std::sync::Arc; +use datafusion_common::{Result, internal_err}; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; /// Holds per-query execution properties and data (such as statement /// starting timestamps). @@ -42,6 +46,12 @@ pub struct ExecutionProps { pub config_options: Option>, /// Providers for scalar variables pub var_providers: Option>>, + /// Maps each logical `Subquery` to its index in `subquery_results`. + /// Populated by the physical planner before calling `create_physical_expr`. + pub subquery_indexes: HashMap, + /// Shared results container for uncorrelated scalar subquery values. + /// Populated at execution time by `ScalarSubqueryExec`. + pub subquery_results: ScalarSubqueryResults, } impl Default for ExecutionProps { @@ -58,6 +68,8 @@ impl ExecutionProps { alias_generator: Arc::new(AliasGenerator::new()), config_options: None, var_providers: None, + subquery_indexes: HashMap::new(), + subquery_results: ScalarSubqueryResults::default(), } } @@ -85,8 +97,7 @@ impl ExecutionProps { &*self } - /// Registers a variable provider, returning the existing - /// provider, if any + /// Registers a variable provider, returning the existing provider, if any pub fn add_var_provider( &mut self, var_type: VarType, @@ -119,15 +130,149 @@ impl ExecutionProps { } } +/// Index of a scalar subquery within a [`ScalarSubqueryResults`] container. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct SubqueryIndex(usize); + +impl SubqueryIndex { + /// Creates a new subquery index. + pub const fn new(index: usize) -> Self { + Self(index) + } + + /// Returns the underlying slot index. + pub const fn as_usize(self) -> usize { + self.0 + } +} + +/// Shared results container for uncorrelated scalar subqueries. +/// +/// Each entry corresponds to one scalar subquery, identified by its index. +/// Each slot is populated at execution time by `ScalarSubqueryExec`, read by +/// `ScalarSubqueryExpr` instances that share this container, and cleared when +/// the plan is reset for re-execution. +#[derive(Clone, Default)] +pub struct ScalarSubqueryResults { + slots: Arc>>>, +} + +impl ScalarSubqueryResults { + /// Creates a new shared results container with `n` empty slots. + pub fn new(n: usize) -> Self { + Self { + slots: Arc::new((0..n).map(|_| Mutex::new(None)).collect()), + } + } + + /// Returns the scalar value stored at `index`, if it has been populated. 
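+    ///
+    /// A hedged usage sketch (values illustrative):
+    ///
+    /// ```ignore
+    /// let results = ScalarSubqueryResults::new(1);
+    /// assert_eq!(results.get(SubqueryIndex::new(0)), None);
+    /// results.set(SubqueryIndex::new(0), ScalarValue::Int32(Some(7)))?;
+    /// assert_eq!(
+    ///     results.get(SubqueryIndex::new(0)),
+    ///     Some(ScalarValue::Int32(Some(7)))
+    /// );
+    /// ```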
+ pub fn get(&self, index: SubqueryIndex) -> Option { + let slot = self.slots.get(index.as_usize())?; + slot.lock().unwrap().clone() + } + + /// Stores `value` in the slot at `index`. + pub fn set(&self, index: SubqueryIndex, value: ScalarValue) -> Result<()> { + let Some(slot) = self.slots.get(index.as_usize()) else { + return internal_err!( + "ScalarSubqueryResults: result index {} is out of bounds", + index.as_usize() + ); + }; + + let mut slot = slot.lock().unwrap(); + if slot.is_some() { + return internal_err!( + "ScalarSubqueryResults: result for index {} was already populated", + index.as_usize() + ); + } + *slot = Some(value); + + Ok(()) + } + + /// Clears all populated results so the container can be reused. + pub fn clear(&self) { + for slot in self.slots.iter() { + *slot.lock().unwrap() = None; + } + } + + /// Returns true if `this` and `other` point to the same shared container. + pub fn ptr_eq(this: &Self, other: &Self) -> bool { + Arc::ptr_eq(&this.slots, &other.slots) + } +} + +impl fmt::Debug for ScalarSubqueryResults { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entries(self.slots.iter().map(|slot| slot.lock().unwrap().clone())) + .finish() + } +} + +impl PartialEq for ScalarSubqueryResults { + fn eq(&self, other: &Self) -> bool { + Self::ptr_eq(self, other) + } +} + +impl Eq for ScalarSubqueryResults {} + +impl Hash for ScalarSubqueryResults { + fn hash(&self, state: &mut H) { + Arc::as_ptr(&self.slots).hash(state); + } +} + #[cfg(test)] mod test { use super::*; + #[test] fn debug() { let props = ExecutionProps::new(); assert_eq!( - "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None }", + "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None, subquery_indexes: {}, subquery_results: [] }", format!("{props:?}") ); } + + #[test] + fn scalar_subquery_results_set_and_get() -> Result<()> { + let results = ScalarSubqueryResults::new(1); + assert_eq!(results.get(SubqueryIndex::new(0)), None); + + results.set(SubqueryIndex::new(0), ScalarValue::Int32(Some(42)))?; + assert_eq!( + results.get(SubqueryIndex::new(0)), + Some(ScalarValue::Int32(Some(42))) + ); + assert!( + results + .set(SubqueryIndex::new(0), ScalarValue::Int32(Some(7))) + .is_err() + ); + + Ok(()) + } + + #[test] + fn scalar_subquery_results_clear() -> Result<()> { + let results = ScalarSubqueryResults::new(1); + results.set(SubqueryIndex::new(0), ScalarValue::Int32(Some(42)))?; + + results.clear(); + + assert_eq!(results.get(SubqueryIndex::new(0)), None); + results.set(SubqueryIndex::new(0), ScalarValue::Int32(Some(7)))?; + assert_eq!( + results.get(SubqueryIndex::new(0)), + Some(ScalarValue::Int32(Some(7))) + ); + + Ok(()) + } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7dba4f0579b97..c3dffe553dc70 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2059,6 +2059,12 @@ impl Expr { .expect("exists closure is infallible") } + /// Returns true if the expression contains a scalar subquery. + pub fn contains_scalar_subquery(&self) -> bool { + self.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + .expect("exists closure is infallible") + } + /// Returns true if the expression node is volatile, i.e. whether it can return /// different results when evaluated multiple times with the same input. 
 /// Note: unlike [`Self::is_volatile`], this function does not consider inputs:
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index d86024295a061..d851e37479096 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -3771,6 +3771,7 @@ impl PartialOrd for Aggregate {
 /// index among identical entries. For example, if the same set appears three
 /// times, the ordinals are 0, 1, 2 and this function returns 2.
 /// Returns 0 when no grouping set is duplicated.
+#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key
 fn max_grouping_set_duplicate_ordinal(group_expr: &[Expr]) -> usize {
     if let Some(Expr::GroupingSet(GroupingSet::GroupingSets(sets))) = group_expr.first() {
         let mut counts: HashMap<&[Expr], usize> = HashMap::new();
diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs
index a1285510da569..41a545372dac3 100644
--- a/datafusion/expr/src/logical_plan/tree_node.rs
+++ b/datafusion/expr/src/logical_plan/tree_node.rs
@@ -808,7 +808,7 @@ impl LogicalPlan {
         transform_down_up_with_subqueries_impl(self, &mut f_down, &mut f_up)
     }
 
-    /// Similarly to [`Self::apply`], calls `f` on this node and its inputs
+    /// Similarly to [`Self::apply`], calls `f` on this node and its inputs,
     /// including subqueries that may appear in expressions such as `IN (SELECT
     /// ...)`.
     pub fn apply_subqueries<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
@@ -821,9 +821,7 @@ impl LogicalPlan {
                 | Expr::InSubquery(InSubquery { subquery, .. })
                 | Expr::SetComparison(SetComparison { subquery, .. })
                 | Expr::ScalarSubquery(subquery) => {
-                    // use a synthetic plan so the collector sees a
-                    // LogicalPlan::Subquery (even though it is
-                    // actually a Subquery alias)
+                    // Wrap in LogicalPlan::Subquery to match f's signature
                     f(&LogicalPlan::Subquery(subquery.clone()))
                 }
                 _ => Ok(TreeNodeRecursion::Continue),
@@ -888,4 +886,18 @@ impl LogicalPlan {
             })
         })
     }
+
+    /// Similar to [`Self::map_subqueries`], but only applies `f` to
+    /// uncorrelated subqueries (those with no outer column references).
+    pub fn map_uncorrelated_subqueries<F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>>(
+        self,
+        mut f: F,
+    ) -> Result<Transformed<Self>> {
+        self.map_subqueries(|subquery_plan| match &subquery_plan {
+            LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => {
+                f(subquery_plan)
+            }
+            _ => Ok(Transformed::no(subquery_plan)),
+        })
+    }
 }
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 4213c23ccc897..c02ba602475f3 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -586,8 +586,12 @@ impl OptimizerRule for CommonSubexprEliminate {
             | LogicalPlan::Unnest(_)
             | LogicalPlan::RecursiveQuery(_) => {
                 // This rule handles recursion itself in a `ApplyOrder::TopDown` like
-                // manner.
-                plan.map_children(|c| self.rewrite(c, config))?
+                // manner. Process uncorrelated subqueries in expressions
+                // (e.g., Expr::ScalarSubquery), then direct children.
+                plan.map_uncorrelated_subqueries(|c| self.rewrite(c, config))?
+                    .transform_sibling(|plan| {
+                        plan.map_children(|c| self.rewrite(c, config))
+                    })?
             }
         };
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index 3cb0516a6d296..8306d4b54c256 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -212,7 +212,12 @@ fn rewrite_children(
     plan: LogicalPlan,
     config: &dyn OptimizerConfig,
 ) -> Result<Transformed<LogicalPlan>> {
-    let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?;
+    // Process uncorrelated subqueries in expressions, then direct children.
+    let transformed_plan = plan
+        .map_uncorrelated_subqueries(|input| optimizer.rewrite(input, config))?
+        .transform_sibling(|plan| {
+            plan.map_children(|input| optimizer.rewrite(input, config))
+        })?;
 
     // recompute schema if the plan was transformed
     if transformed_plan.transformed {
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs
index 14badcf1435d5..af944abc6f0b4 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -135,9 +135,11 @@ fn optimize_projections(
     // their parents' required indices.
     match plan {
         LogicalPlan::Projection(proj) => {
-            return merge_consecutive_projections(proj)?.transform_data(|proj| {
-                rewrite_projection_given_requirements(proj, config, &indices)
-            });
+            return merge_consecutive_projections(proj)?
+                .transform_data(|proj| {
+                    rewrite_projection_given_requirements(proj, config, &indices)
+                })?
+                .transform_data(|plan| optimize_subqueries(plan, config));
         }
         LogicalPlan::Aggregate(aggregate) => {
             // Split parent requirements to GROUP BY and aggregate sections:
@@ -222,7 +224,8 @@
                     new_aggr_expr,
                 )
                 .map(LogicalPlan::Aggregate)
-            });
+            })?
+            .transform_data(|plan| optimize_subqueries(plan, config));
         }
         LogicalPlan::Window(window) => {
             let input_schema = Arc::clone(window.input.schema());
@@ -262,7 +265,8 @@
                     .map(LogicalPlan::Window)
                     .map(Transformed::yes)
             }
-            });
+            })?
+            .transform_data(|plan| optimize_subqueries(plan, config));
         }
         LogicalPlan::TableScan(table_scan) => {
            let TableScan {
@@ -283,7 +287,8 @@
             let new_scan =
                 TableScan::try_new(table_name, source, Some(projection), filters, fetch)?;
-            return Ok(Transformed::yes(LogicalPlan::TableScan(new_scan)));
+            return Transformed::yes(LogicalPlan::TableScan(new_scan))
+                .transform_data(|plan| optimize_subqueries(plan, config));
         }
         // Other node types are handled below
         _ => {}
@@ -476,6 +481,9 @@
         )
     })?;
 
+    let transformed_plan =
+        transformed_plan.transform_data(|plan| optimize_subqueries(plan, config))?;
+
     // If any of the children are transformed, we need to potentially update the plan's schema
     if transformed_plan.transformed {
         transformed_plan.map_data(|plan| plan.recompute_schema())
@@ -484,8 +492,19 @@
     }
 }
 
-/// Merges consecutive projections.
-///
+/// Optimizes uncorrelated subquery plans embedded in expressions of the given
+/// plan node (e.g., `Expr::ScalarSubquery`). `map_children` only visits direct
+/// plan inputs, so subqueries must be handled separately.
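+///
+/// For example (hypothetical plan shapes, for illustration): in a plan like
+/// `Projection[x + (SELECT max(y) FROM u)] -> TableScan(t)`, the scan is a
+/// direct input that `map_children` visits, while the `max(y)` subquery plan
+/// hangs off the projection expression and is reached only through this
+/// helper.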
+fn optimize_subqueries( + plan: LogicalPlan, + config: &dyn OptimizerConfig, +) -> Result> { + plan.map_uncorrelated_subqueries(|subquery_plan| { + let indices = RequiredIndices::new_for_all_exprs(&subquery_plan); + optimize_projections(subquery_plan, config, indices) + }) +} + /// Given a projection `proj`, this function attempts to merge it with a previous /// projection if it exists and if merging is beneficial. Merging is considered /// beneficial when expressions in the current projection are non-trivial and diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index e6455099077c0..35b2e3e3b67de 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1144,8 +1144,16 @@ impl OptimizerRule for PushDownFilter { LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); - let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = + // Filters containing scalar subqueries cannot be pushed to + // providers because the subquery result is not available + // until execution time. + let (subquery_filters, pushdown_candidates): (Vec<&Expr>, Vec<&Expr>) = filter_predicates + .into_iter() + .partition(|pred| pred.contains_scalar_subquery()); + + let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = + pushdown_candidates .into_iter() .partition(|pred| pred.is_volatile()); @@ -1178,11 +1186,13 @@ impl OptimizerRule for PushDownFilter { .cloned() .collect(); - // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters + // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, + // and also include volatile and subquery-containing filters let new_predicate: Vec = zip .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact) .map(|(pred, _)| pred) .chain(volatile_filters) + .chain(subquery_filters) .cloned() .collect(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index c54cd287dbb46..941fcffb798f5 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s +//! [`ScalarSubqueryToJoin`] rewriting correlated scalar subquery filters to `JOIN`s use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; @@ -36,8 +36,8 @@ use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, expr}; -/// Optimizer rule for rewriting subquery filters to joins -/// and places additional projection on top of the filter, to preserve +/// Optimizer rule that rewrites correlated scalar subquery filters to joins and +/// places an additional projection on top of the filter, to preserve the /// original schema. #[derive(Default, Debug)] pub struct ScalarSubqueryToJoin {} @@ -48,10 +48,15 @@ impl ScalarSubqueryToJoin { Self::default() } - /// Finds expressions that have a scalar subquery in them (and recurses when found) + /// Finds expressions that contain correlated scalar subqueries (and + /// recurses when found). /// /// # Arguments - /// * `predicate` - A conjunction to split and search + /// * `predicate` - A conjunction to split and search. 
+ /// * `alias_gen` - Generator used to produce unique aliases for each + /// extracted scalar subquery (e.g. `__scalar_sq_1`, `__scalar_sq_2`). + /// Each subquery is replaced by a column reference using the generated + /// alias, and the same alias is later used to construct the join. /// /// Returns a tuple (subqueries, alias) fn extract_subquery_exprs( @@ -85,7 +90,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { LogicalPlan::Filter(filter) => { // Optimization: skip the rest of the rule and its copies if // there are no scalar subqueries - if !contains_scalar_subquery(&filter.predicate) { + if !contains_correlated_scalar_subquery(&filter.predicate) { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } @@ -137,9 +142,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { Ok(Transformed::yes(new_plan)) } LogicalPlan::Projection(projection) => { - // Optimization: skip the rest of the rule and its copies if - // there are no scalar subqueries - if !projection.expr.iter().any(contains_scalar_subquery) { + // Optimization: skip the rest of the rule and its copies if there + // are no correlated scalar subqueries + if !projection + .expr + .iter() + .any(contains_correlated_scalar_subquery) + { return Ok(Transformed::no(LogicalPlan::Projection(projection))); } @@ -226,11 +235,14 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } -/// Returns true if the expression has a scalar subquery somewhere in it -/// false otherwise -fn contains_scalar_subquery(expr: &Expr) -> bool { - expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) - .expect("Inner is always Ok") +/// Returns true if the expression contains a correlated scalar subquery, false +/// otherwise. Uncorrelated scalar subqueries are handled by the physical +/// planner via `ScalarSubqueryExec` and do not need to be converted to joins. +fn contains_correlated_scalar_subquery(expr: &Expr) -> bool { + expr.exists(|expr| { + Ok(matches!(expr, Expr::ScalarSubquery(sq) if !sq.outer_ref_columns.is_empty())) + }) + .expect("Inner is always Ok") } struct ExtractScalarSubQuery<'a> { @@ -243,19 +255,21 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { fn f_down(&mut self, expr: Expr) -> Result> { match expr { - Expr::ScalarSubquery(subquery) => { - let subqry_alias = self.alias_gen.next("__scalar_sq"); - self.sub_query_info - .push((subquery.clone(), subqry_alias.clone())); + // Skip uncorrelated scalar subqueries + Expr::ScalarSubquery(ref subquery) + if !subquery.outer_ref_columns.is_empty() => + { + let subquery = subquery.clone(); let scalar_expr = subquery .subquery .head_output_expr()? 
.map_or(plan_err!("single expression required."), Ok)?; + let subqry_alias = self.alias_gen.next("__scalar_sq"); + let col = + create_col_from_scalar_expr(&scalar_expr, subqry_alias.clone())?; + self.sub_query_info.push((subquery, subqry_alias)); Ok(Transformed::new( - Expr::Column(create_col_from_scalar_expr( - &scalar_expr, - subqry_alias, - )?), + Expr::Column(col), true, TreeNodeRecursion::Jump, )) @@ -627,15 +641,13 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1032,14 +1044,12 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey < () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1062,14 +1072,12 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 
[max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1161,19 +1169,16 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N] - Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey BETWEEN () AND () [c_custkey:Int64, c_name:Utf8] + Subquery: [min(orders.o_custkey):Int64;N] + Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index ad7427acf92ac..b1ba313e1deb1 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -136,15 +136,13 @@ fn subquery_filter_with_cast() -> Result<()> { assert_snapshot!( format!("{plan}"), @r#" - Projection: test.col_int32 - Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.avg(test.col_int32) - TableScan: test projection=[col_int32] - SubqueryAlias: __scalar_sq_1 - Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]] - Projection: test.col_int32 - Filter: __common_expr_4 >= Date32("2002-05-08") AND __common_expr_4 <= Date32("2002-05-13") - Projection: 
CAST(test.col_utf8 AS Date32) AS __common_expr_4, test.col_int32 - TableScan: test projection=[col_int32, col_utf8] + Filter: CAST(test.col_int32 AS Float64) > () + Subquery: + Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]] + Projection: test.col_int32 + Filter: CAST(test.col_utf8 AS Date32) >= Date32("2002-05-08") AND CAST(test.col_utf8 AS Date32) <= Date32("2002-05-13") + TableScan: test projection=[col_int32, col_utf8] + TableScan: test projection=[col_int32] "# ); Ok(()) diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index bedd348dab92f..9c567ce862149 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -40,6 +40,7 @@ mod physical_expr; pub mod planner; pub mod projection; mod scalar_function; +pub mod scalar_subquery; pub mod simplifier; pub mod statistics; pub mod utils; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 0880f1a139d14..5d4c5d4ce1b69 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use crate::ScalarFunctionExpr; +use crate::scalar_subquery::ScalarSubqueryExpr; use crate::{ PhysicalExpr, expressions::{self, Column, Literal, binary, like, similar_to}, @@ -396,6 +397,35 @@ pub fn create_physical_expr( expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, + Expr::ScalarSubquery(sq) => { + match execution_props.subquery_indexes.get(sq) { + Some(&index) => { + let schema = sq.subquery.schema(); + if schema.fields().len() != 1 { + return plan_err!( + "Scalar subquery must return exactly one column, got {}", + schema.fields().len() + ); + } + let dt = schema.field(0).data_type().clone(); + let nullable = schema.field(0).is_nullable(); + Ok(Arc::new(ScalarSubqueryExpr::new( + dt, + nullable, + index, + execution_props.subquery_results.clone(), + ))) + } + None => { + // Not found: either a correlated subquery that wasn't + // rewritten to a join, or an uncorrelated one that wasn't + // registered by the physical planner. + not_impl_err!( + "Physical plan does not support logical expression {e:?}" + ) + } + } + } Expr::Placeholder(Placeholder { id, .. }) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } diff --git a/datafusion/physical-expr/src/scalar_subquery.rs b/datafusion/physical-expr/src/scalar_subquery.rs new file mode 100644 index 0000000000000..ea00847151e66 --- /dev/null +++ b/datafusion/physical-expr/src/scalar_subquery.rs @@ -0,0 +1,240 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical expression for uncorrelated scalar subqueries. 
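+//!
+//! A minimal usage sketch (hedged; values illustrative, error handling elided):
+//!
+//! ```ignore
+//! // The planner constructs the expression; ScalarSubqueryExec fills the
+//! // shared slot before the main plan produces data.
+//! let results = ScalarSubqueryResults::new(1);
+//! let expr = ScalarSubqueryExpr::new(
+//!     DataType::Int32,
+//!     false,
+//!     SubqueryIndex::new(0),
+//!     results.clone(),
+//! );
+//! results.set(SubqueryIndex::new(0), ScalarValue::Int32(Some(42)))?;
+//! // `expr.evaluate(&batch)` now yields ColumnarValue::Scalar(Int32(42)).
+//! ```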
+ +use std::fmt; +use std::hash::Hash; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, internal_datafusion_err}; +use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// A physical expression whose value is provided by a scalar subquery. +/// +/// Subquery execution is handled by `ScalarSubqueryExec`, which stores the +/// result in a shared [`ScalarSubqueryResults`] container. This expression +/// simply reads from that container at the appropriate index. +#[derive(Debug)] +pub struct ScalarSubqueryExpr { + data_type: DataType, + nullable: bool, + /// Index of this subquery in the shared results container. + index: SubqueryIndex, + /// Shared results container populated by `ScalarSubqueryExec`. + results: ScalarSubqueryResults, +} + +impl ScalarSubqueryExpr { + pub fn new( + data_type: DataType, + nullable: bool, + index: SubqueryIndex, + results: ScalarSubqueryResults, + ) -> Self { + Self { + data_type, + nullable, + index, + results, + } + } + + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + pub fn nullable(&self) -> bool { + self.nullable + } + + /// Returns the index of this subquery in the shared results container. + pub fn index(&self) -> SubqueryIndex { + self.index + } + + pub fn results(&self) -> &ScalarSubqueryResults { + &self.results + } +} + +impl fmt::Display for ScalarSubqueryExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.results.get(self.index) { + Some(v) => write!(f, "scalar_subquery({v})"), + None => write!(f, "scalar_subquery()"), + } + } +} + +// Two ScalarSubqueryExprs are considered the "same" if they refer to the +// same underlying shared results container and the same index within it. 
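+// For example, the same expression cloned into two places in one plan
+// compares equal, while structurally identical expressions built for two
+// independent plans do not (see `test_identity_equality` below).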
+impl Hash for ScalarSubqueryExpr {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.results.hash(state);
+        self.index.hash(state);
+    }
+}
+
+impl PartialEq for ScalarSubqueryExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.results == other.results && self.index == other.index
+    }
+}
+
+impl Eq for ScalarSubqueryExpr {}
+
+impl PhysicalExpr for ScalarSubqueryExpr {
+    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
+        Ok(Arc::new(Field::new(
+            "scalar_subquery",
+            self.data_type.clone(),
+            self.nullable,
+        )))
+    }
+
+    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+        let value = self.results.get(self.index).ok_or_else(|| {
+            internal_datafusion_err!(
+                "ScalarSubqueryExpr evaluated before the subquery was executed"
+            )
+        })?;
+        Ok(ColumnarValue::Scalar(value))
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(self)
+    }
+
+    fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
+        Ok(ExprProperties::new_unknown().with_order(SortProperties::Singleton))
+    }
+
+    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "(scalar subquery)")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use arrow::array::Int32Array;
+    use arrow::datatypes::Field;
+    use datafusion_common::ScalarValue;
+
+    fn make_results(values: Vec<Option<ScalarValue>>) -> ScalarSubqueryResults {
+        let results = ScalarSubqueryResults::new(values.len());
+        for (index, value) in values.into_iter().enumerate() {
+            if let Some(value) = value {
+                results.set(SubqueryIndex::new(index), value).unwrap();
+            }
+        }
+        results
+    }
+
+    #[test]
+    fn test_evaluate_with_value() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+        let a = Int32Array::from(vec![1, 2, 3]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+
+        let results = make_results(vec![Some(ScalarValue::Int32(Some(42)))]);
+        let expr = ScalarSubqueryExpr::new(
+            DataType::Int32,
+            false,
+            SubqueryIndex::new(0),
+            results,
+        );
+
+        let result = expr.evaluate(&batch)?;
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(42))) => {}
+            other => panic!("Expected Scalar(Int32(42)), got {other:?}"),
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn test_evaluate_before_populated() {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+        let a = Int32Array::from(vec![1]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();
+
+        let results = ScalarSubqueryResults::new(1);
+        let expr = ScalarSubqueryExpr::new(
+            DataType::Int32,
+            false,
+            SubqueryIndex::new(0),
+            results,
+        );
+
+        let result = expr.evaluate(&batch);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_identity_equality() {
+        let results = make_results(vec![None, None]);
+
+        let e1a = ScalarSubqueryExpr::new(
+            DataType::Int32,
+            false,
+            SubqueryIndex::new(0),
+            results.clone(),
+        );
+        let e1b = ScalarSubqueryExpr::new(
+            DataType::Int32,
+            false,
+            SubqueryIndex::new(0),
+            results.clone(),
+        );
+        let e2 = ScalarSubqueryExpr::new(
+            DataType::Int32,
+            false,
+            SubqueryIndex::new(1),
+            results.clone(),
+        );
+
+        // Same container + same index → equal
+        assert_eq!(e1a, e1b);
+        // Same container, different index → not equal
+        assert_ne!(e1a, e2);
+
+        // Different container, same index → not equal
+        let other_results = make_results(vec![None]);
+        let e3 = ScalarSubqueryExpr::new(
+            DataType::Int32,
+            false,
+            SubqueryIndex::new(0),
+            other_results,
+        );
+        assert_ne!(e1a, e3);
+    }
+}
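Before moving on to the exec side of this patch, the index-based handshake the two new pieces rely on is easiest to see in isolation. A minimal self-contained sketch (simplified stand-in types, not the shipped `ScalarSubqueryResults`/`SubqueryIndex` API):

```rust
use std::sync::{Arc, Mutex};

/// Stand-in for the shared results container: one slot per subquery.
#[derive(Clone)]
struct Results(Arc<Mutex<Vec<Option<i64>>>>);

impl Results {
    fn new(len: usize) -> Self {
        Results(Arc::new(Mutex::new(vec![None; len])))
    }
    fn set(&self, index: usize, value: i64) {
        self.0.lock().unwrap()[index] = Some(value);
    }
    fn get(&self, index: usize) -> Option<i64> {
        self.0.lock().unwrap()[index]
    }
}

fn main() {
    // Planning: the subquery is assigned a stable index, and both sides
    // (the exec node and the expression) receive the same container.
    let results = Results::new(1);
    let index = 0;

    // Execution: the exec node runs the subquery once and publishes the value.
    results.set(index, 42);

    // Evaluation: the expression reads the published value by index. Before
    // `set`, `get` returns None, which the real expression turns into an error.
    assert_eq!(results.get(index), Some(42));
}
```

The real container stores `ScalarValue`s, which lets it distinguish "never executed" (an empty slot, an internal error in `evaluate`) from "executed and produced NULL" (a typed NULL value for a zero-row subquery).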
diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs
index 54fc97c154206..3005e975424b4 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -85,6 +85,7 @@ pub mod placeholder_row;
 pub mod projection;
 pub mod recursive_query;
 pub mod repartition;
+pub mod scalar_subquery;
 pub mod sort_pushdown;
 pub mod sorts;
 pub mod spill;
diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs
new file mode 100644
index 0000000000000..b6ad7f91f097d
--- /dev/null
+++ b/datafusion/physical-plan/src/scalar_subquery.rs
@@ -0,0 +1,574 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution plan for uncorrelated scalar subqueries.
+//!
+//! [`ScalarSubqueryExec`] wraps a main input plan and a set of subquery plans.
+//! At execution time, it runs each subquery exactly once, extracts the scalar
+//! result, and populates a shared [`ScalarSubqueryResults`] container that
+//! [`ScalarSubqueryExpr`] instances hold directly and read from by index.
+//!
+//! [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr
+
+use std::fmt;
+use std::sync::Arc;
+
+use datafusion_common::tree_node::TreeNodeRecursion;
+use datafusion_common::{Result, ScalarValue, Statistics, exec_err, internal_err};
+use datafusion_execution::TaskContext;
+use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex};
+use datafusion_physical_expr::PhysicalExpr;
+
+use crate::execution_plan::{CardinalityEffect, ExecutionPlan, PlanProperties};
+use crate::joins::utils::{OnceAsync, OnceFut};
+use crate::stream::RecordBatchStreamAdapter;
+use crate::{DisplayAs, DisplayFormatType, SendableRecordBatchStream};
+
+use futures::StreamExt;
+use futures::TryStreamExt;
+
+/// Links a scalar subquery's execution plan to its index in the shared results
+/// container. The [`ScalarSubqueryExec`] that owns these links populates
+/// `results[index]` at execution time, and [`ScalarSubqueryExpr`] instances
+/// with the same index read from it.
+///
+/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr
+#[derive(Debug, Clone)]
+pub struct ScalarSubqueryLink {
+    /// The physical plan for the subquery.
+    pub plan: Arc<dyn ExecutionPlan>,
+    /// Index into the shared results container.
+    pub index: SubqueryIndex,
+}
+
+/// Manages execution of uncorrelated scalar subqueries for a single plan
+/// level.
+///
+/// From a query-results perspective, this node is a pass-through: it yields
+/// the same batches as its main input and exists only to populate scalar
+/// subquery results as a side effect before those batches are produced.
+///
+/// The first child node is the **main input plan**, whose batches are passed
+/// through unchanged. The remaining children are **subquery plans**, each of
+/// which must produce at most one row. Before any batches from the main
+/// input are yielded, all subquery plans are executed and their scalar results
+/// are stored in a shared [`ScalarSubqueryResults`] container owned by this
+/// node. [`ScalarSubqueryExpr`] nodes embedded in the main input's expressions
+/// hold the same container and read from it by index.
+///
+/// All subqueries are evaluated eagerly when the first output partition is
+/// requested, before any rows from the main input are produced.
+///
+/// TODO: Consider overlapping computation of the subqueries with evaluating the
+/// main query.
+///
+/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr
+#[derive(Debug)]
+pub struct ScalarSubqueryExec {
+    /// The main input plan whose output is passed through.
+    input: Arc<dyn ExecutionPlan>,
+    /// Subquery plans and their result indexes.
+    subqueries: Vec<ScalarSubqueryLink>,
+    /// Shared one-time async computation of subquery results.
+    subquery_future: Arc<OnceAsync<()>>,
+    /// Shared results container; the corresponding `ScalarSubqueryExpr`
+    /// nodes in the input plan hold the same underlying container.
+    results: ScalarSubqueryResults,
+    /// Cached plan properties (copied from input).
+    cache: Arc<PlanProperties>,
+}
+
+impl ScalarSubqueryExec {
+    pub fn new(
+        input: Arc<dyn ExecutionPlan>,
+        subqueries: Vec<ScalarSubqueryLink>,
+        results: ScalarSubqueryResults,
+    ) -> Self {
+        let cache = Arc::clone(input.properties());
+        Self {
+            input,
+            subqueries,
+            subquery_future: Arc::default(),
+            results,
+            cache,
+        }
+    }
+
+    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.input
+    }
+
+    pub fn subqueries(&self) -> &[ScalarSubqueryLink] {
+        &self.subqueries
+    }
+
+    pub fn results(&self) -> &ScalarSubqueryResults {
+        &self.results
+    }
+
+    /// Returns a per-child bool vec that is `true` for the main input
+    /// (child 0) and `false` for every subquery child.
+    fn true_for_input_only(&self) -> Vec<bool> {
+        std::iter::once(true)
+            .chain(std::iter::repeat_n(false, self.subqueries.len()))
+            .collect()
+    }
+}
+
+impl DisplayAs for ScalarSubqueryExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(
+                    f,
+                    "ScalarSubqueryExec: subqueries={}",
+                    self.subqueries.len()
+                )
+            }
+            DisplayFormatType::TreeRender => {
+                write!(f, "")
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for ScalarSubqueryExec {
+    fn name(&self) -> &'static str {
+        "ScalarSubqueryExec"
+    }
+
+    fn properties(&self) -> &Arc<PlanProperties> {
+        &self.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        let mut children = vec![&self.input];
+        for sq in &self.subqueries {
+            children.push(&sq.plan);
+        }
+        children
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        mut children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // First child is the main input, the rest are subquery plans.
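+        // (`children` arrives in the order `children()` reports them:
+        // `[input, subqueries[0].plan, subqueries[1].plan, ...]`, so after
+        // removing the input the remainder zips positionally with
+        // `self.subqueries`, keeping each link's result index.)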
+        let input = children.remove(0);
+        let subqueries = self
+            .subqueries
+            .iter()
+            .zip(children)
+            .map(|(sq, new_plan)| ScalarSubqueryLink {
+                plan: new_plan,
+                index: sq.index,
+            })
+            .collect();
+        Ok(Arc::new(ScalarSubqueryExec::new(
+            input,
+            subqueries,
+            self.results.clone(),
+        )))
+    }
+
+    fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
+        self.results.clear();
+        Ok(Arc::new(ScalarSubqueryExec {
+            input: Arc::clone(&self.input),
+            subqueries: self.subqueries.clone(),
+            subquery_future: Arc::default(),
+            results: self.results.clone(),
+            cache: Arc::clone(&self.cache),
+        }))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let subqueries = self.subqueries.clone();
+        let results = self.results.clone();
+        let subquery_ctx = Arc::clone(&context);
+        let mut subquery_future = self.subquery_future.try_once(move || {
+            Ok(async move { execute_subqueries(subqueries, results, subquery_ctx).await })
+        })?;
+        let input = Arc::clone(&self.input);
+        let schema = self.schema();
+
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            schema,
+            futures::stream::once(async move {
+                // Execute all subqueries exactly once, even when multiple
+                // partitions call execute() concurrently.
+                wait_for_subqueries(&mut subquery_future).await?;
+
+                // Now that the subqueries have finished execution, we can
+                // safely execute the main input
+                input.execute(partition, context)
+            })
+            .try_flatten(),
+        )))
+    }
+
+    fn apply_expressions(
+        &self,
+        _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        Ok(TreeNodeRecursion::Continue)
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        // Only the main input (first child); subquery children don't contribute
+        // to ordering.
+        self.true_for_input_only()
+    }
+
+    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
+        // Only the main input; subquery children produce at most one row, so
+        // repartitioning them has no benefit.
+        self.true_for_input_only()
+    }
+
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
+        self.input.partition_statistics(partition)
+    }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::Equal
+    }
+}
+
+/// Wait for the subquery execution future to complete.
+async fn wait_for_subqueries(fut: &mut OnceFut<()>) -> Result<()> {
+    std::future::poll_fn(|cx| fut.get_shared(cx)).await?;
+    Ok(())
+}
+
+async fn execute_subqueries(
+    subqueries: Vec<ScalarSubqueryLink>,
+    results: ScalarSubqueryResults,
+    context: Arc<TaskContext>,
+) -> Result<()> {
+    // Evaluate subqueries in parallel; wait for them all to finish evaluation
+    // before returning.
+    let futures = subqueries.iter().map(|sq| {
+        let plan = Arc::clone(&sq.plan);
+        let ctx = Arc::clone(&context);
+        let results = results.clone();
+        let index = sq.index;
+        async move {
+            let value = execute_scalar_subquery(plan, ctx).await?;
+            results.set(index, value)?;
+            Ok(()) as Result<()>
+        }
+    });
+    futures::future::try_join_all(futures).await?;
+    Ok(())
+}
+
+/// Execute a single subquery plan and extract the scalar value.
+/// Returns NULL for 0 rows, the scalar value for exactly 1 row,
+/// or an error for >1 rows.
+async fn execute_scalar_subquery(
+    plan: Arc<dyn ExecutionPlan>,
+    context: Arc<TaskContext>,
+) -> Result<ScalarValue> {
+    let schema = plan.schema();
+    if schema.fields().len() != 1 {
+        // Should be enforced by the physical planner.
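+        // (`create_physical_expr` applies the same single-column check when it
+        // builds the corresponding `ScalarSubqueryExpr`, so reaching this
+        // branch indicates a planner bug rather than a user error.)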
+        return internal_err!(
+            "Scalar subquery must return exactly one column, got {}",
+            schema.fields().len()
+        );
+    }
+
+    let mut stream = crate::execute_stream(plan, context)?;
+    let mut result: Option<ScalarValue> = None;
+
+    while let Some(batch) = stream.next().await.transpose()? {
+        if batch.num_rows() == 0 {
+            continue;
+        }
+        if result.is_some() || batch.num_rows() > 1 {
+            return exec_err!("Scalar subquery returned more than one row");
+        }
+        result = Some(ScalarValue::try_from_array(batch.column(0), 0)?);
+    }
+
+    // 0 rows → typed NULL per SQL semantics
+    match result {
+        Some(v) => Ok(v),
+        None => ScalarValue::try_from(schema.field(0).data_type()),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::{self, TestMemoryExec};
+    use crate::{
+        execution_plan::reset_plan_states,
+        projection::{ProjectionExec, ProjectionExpr},
+    };
+
+    use std::sync::atomic::{AtomicUsize, Ordering};
+
+    use crate::test::exec::ErrorExec;
+    use arrow::array::{Int32Array, Int64Array};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr;
+
+    enum ExpectedSubqueryResult {
+        Value(ScalarValue),
+        Error(&'static str),
+    }
+
+    #[derive(Debug)]
+    struct CountingExec {
+        inner: Arc<dyn ExecutionPlan>,
+        execute_calls: Arc<AtomicUsize>,
+    }
+
+    impl CountingExec {
+        fn new(inner: Arc<dyn ExecutionPlan>, execute_calls: Arc<AtomicUsize>) -> Self {
+            Self {
+                inner,
+                execute_calls,
+            }
+        }
+    }
+
+    impl DisplayAs for CountingExec {
+        fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+            match t {
+                DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                    write!(f, "CountingExec")
+                }
+                DisplayFormatType::TreeRender => write!(f, ""),
+            }
+        }
+    }
+
+    impl ExecutionPlan for CountingExec {
+        fn name(&self) -> &'static str {
+            "CountingExec"
+        }
+
+        fn properties(&self) -> &Arc<PlanProperties> {
+            self.inner.properties()
+        }
+
+        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+            vec![&self.inner]
+        }
+
+        fn with_new_children(
+            self: Arc<Self>,
+            mut children: Vec<Arc<dyn ExecutionPlan>>,
+        ) -> Result<Arc<dyn ExecutionPlan>> {
+            Ok(Arc::new(Self::new(
+                children.remove(0),
+                Arc::clone(&self.execute_calls),
+            )))
+        }
+
+        fn apply_expressions(
+            &self,
+            _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result<TreeNodeRecursion>,
+        ) -> Result<TreeNodeRecursion> {
+            Ok(TreeNodeRecursion::Continue)
+        }
+
+        fn execute(
+            &self,
+            partition: usize,
+            context: Arc<TaskContext>,
+        ) -> Result<SendableRecordBatchStream> {
+            self.execute_calls.fetch_add(1, Ordering::SeqCst);
+            self.inner.execute(partition, context)
+        }
+    }
+
+    fn make_subquery_plan(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
+        let schema = batches[0].schema();
+        TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap()
+    }
+
+    fn int32_batch(values: Vec<i32>) -> RecordBatch {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
+    }
+
+    fn empty_int64_batch() -> RecordBatch {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
+        RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![] as Vec<i64>))])
+            .unwrap()
+    }
+
+    fn placeholder_input() -> Arc<dyn ExecutionPlan> {
+        Arc::new(crate::placeholder_row::PlaceholderRowExec::new(
+            test::aggr_test_schema(),
+        ))
+    }
+
+    fn single_subquery_exec(
+        input: Arc<dyn ExecutionPlan>,
+        subquery_plan: Arc<dyn ExecutionPlan>,
+        results: ScalarSubqueryResults,
+    ) -> ScalarSubqueryExec {
+        ScalarSubqueryExec::new(
+            input,
+            vec![ScalarSubqueryLink {
+                plan: subquery_plan,
+                index: SubqueryIndex::new(0),
+            }],
+            results,
+        )
+    }
+
+    fn scalar_subquery_projection_input(
+        results: ScalarSubqueryResults,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(ProjectionExec::try_new(
+            vec![ProjectionExpr {
+                expr: Arc::new(ScalarSubqueryExpr::new(
+                    DataType::Int32,
+                    false,
+                    SubqueryIndex::new(0),
+                    results,
+                )),
+                alias: "sq".to_string(),
+            }],
+            placeholder_input(),
+        )?))
+    }
+
+    fn extract_single_int32_value(batches: &[RecordBatch]) -> i32 {
+        assert_eq!(batches.len(), 1);
+        let values = batches[0]
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        assert_eq!(values.len(), 1);
+        values.value(0)
+    }
+
+    #[tokio::test]
+    async fn test_execute_scalar_subquery_row_count_semantics() -> Result<()> {
+        for (name, plan, expected) in [
+            (
+                "single_row",
+                make_subquery_plan(vec![int32_batch(vec![42])]),
+                ExpectedSubqueryResult::Value(ScalarValue::Int32(Some(42))),
+            ),
+            (
+                "zero_rows",
+                make_subquery_plan(vec![empty_int64_batch()]),
+                ExpectedSubqueryResult::Value(ScalarValue::Int64(None)),
+            ),
+            (
+                "multiple_rows",
+                make_subquery_plan(vec![int32_batch(vec![1, 2, 3])]),
+                ExpectedSubqueryResult::Error("more than one row"),
+            ),
+        ] {
+            let actual =
+                execute_scalar_subquery(plan, Arc::new(TaskContext::default())).await;
+            match expected {
+                ExpectedSubqueryResult::Value(expected) => {
+                    assert_eq!(actual?, expected, "{name}");
+                }
+                ExpectedSubqueryResult::Error(expected) => {
+                    let err = actual.expect_err(name);
+                    assert!(
+                        err.to_string().contains(expected),
+                        "{name}: expected error containing '{expected}', got {err}"
+                    );
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_failed_subquery_is_not_retried() -> Result<()> {
+        let execute_calls = Arc::new(AtomicUsize::new(0));
+        let subquery_plan = Arc::new(CountingExec::new(
+            Arc::new(ErrorExec::new()),
+            Arc::clone(&execute_calls),
+        ));
+        let exec = single_subquery_exec(
+            placeholder_input(),
+            subquery_plan,
+            ScalarSubqueryResults::new(1),
+        );
+
+        let ctx = Arc::new(TaskContext::default());
+        let stream = exec.execute(0, Arc::clone(&ctx))?;
+        assert!(crate::common::collect(stream).await.is_err());
+
+        let stream = exec.execute(0, ctx)?;
+        assert!(crate::common::collect(stream).await.is_err());
+
+        assert_eq!(execute_calls.load(Ordering::SeqCst), 1);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_reset_state_clears_results_and_reexecutes_subqueries() -> Result<()> {
+        let execute_calls = Arc::new(AtomicUsize::new(0));
+        let results = ScalarSubqueryResults::new(1);
+        let subquery_plan = Arc::new(CountingExec::new(
+            make_subquery_plan(vec![int32_batch(vec![42])]),
+            Arc::clone(&execute_calls),
+        ));
+        let exec: Arc<dyn ExecutionPlan> = Arc::new(single_subquery_exec(
+            scalar_subquery_projection_input(results.clone())?,
+            subquery_plan,
+            results.clone(),
+        ));
+
+        let batches =
+            crate::common::collect(exec.execute(0, Arc::new(TaskContext::default()))?)
+ .await?; + assert_eq!(extract_single_int32_value(&batches), 42); + assert_eq!( + results.get(SubqueryIndex::new(0)), + Some(ScalarValue::Int32(Some(42))) + ); + + let reset_exec = reset_plan_states(Arc::clone(&exec))?; + assert_eq!(results.get(SubqueryIndex::new(0)), None); + + let reset_batches = crate::common::collect( + reset_exec.execute(0, Arc::new(TaskContext::default()))?, + ) + .await?; + assert_eq!(extract_single_int32_value(&reset_batches), 42); + assert_eq!( + results.get(SubqueryIndex::new(0)), + Some(ScalarValue::Int32(Some(42))) + ); + assert_eq!(execute_calls.load(Ordering::SeqCst), 2); + + Ok(()) + } +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c61226fb526f6..29b71929f1e84 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -426,6 +426,8 @@ message LogicalExprNode { Unnest unnest = 35; + // Subquery expressions + ScalarSubqueryExprNode scalar_subquery_expr = 36; } } @@ -433,6 +435,15 @@ message Wildcard { TableReference qualifier = 1; } +message SubqueryNode { + LogicalPlanNode subquery = 1; + repeated LogicalExprNode outer_ref_columns = 2; +} + +message ScalarSubqueryExprNode { + SubqueryNode subquery = 1; +} + message PlaceholderNode { string id = 1; // We serialize the data type, metadata, and nullability separately to maintain @@ -775,6 +786,7 @@ message PhysicalPlanNode { AsyncFuncExecNode async_func = 36; BufferExecNode buffer = 37; ArrowScanExecNode arrow_scan = 38; + ScalarSubqueryExecNode scalar_subquery = 39; } } @@ -920,6 +932,8 @@ message PhysicalExprNode { UnknownColumn unknown_column = 20; PhysicalHashExprNode hash_expr = 21; + + PhysicalScalarSubqueryExprNode scalar_subquery = 22; } } @@ -1477,4 +1491,15 @@ message AsyncFuncExecNode { message BufferExecNode { PhysicalPlanNode input = 1; uint64 capacity = 2; +} + +message ScalarSubqueryExecNode { + PhysicalPlanNode input = 1; + repeated PhysicalPlanNode subqueries = 2; +} + +message PhysicalScalarSubqueryExprNode { + datafusion_common.ArrowType data_type = 1; + bool nullable = 2; + uint32 index = 3; } \ No newline at end of file diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 84b15ea9a8920..2b7d7ed8e849b 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -22,28 +22,20 @@ use crate::logical_plan::{ }; use crate::physical_plan::{ DefaultPhysicalExtensionCodec, DefaultPhysicalProtoConverter, PhysicalExtensionCodec, - PhysicalProtoConverterExtension, + PhysicalPlanDecodeContext, PhysicalProtoConverterExtension, }; use crate::protobuf; use datafusion_common::{Result, plan_datafusion_err}; use datafusion_execution::TaskContext; -use datafusion_expr::{ - AggregateUDF, Expr, LogicalPlan, Volatility, WindowUDF, create_udaf, create_udf, - create_udwf, -}; +use datafusion_expr::{Expr, LogicalPlan}; use prost::{ Message, bytes::{Bytes, BytesMut}, }; use std::sync::Arc; -// Reexport Bytes which appears in the API -use datafusion_execution::registry::FunctionRegistry; -use datafusion_expr::planner::ExprPlanner; use datafusion_physical_plan::ExecutionPlan; -mod registry; - /// Encodes something (such as [`Expr`]) to/from a stream of /// bytes. /// @@ -65,26 +57,21 @@ pub trait Serializeable: Sized { /// Convert `self` to an opaque byte stream fn to_bytes(&self) -> Result; - /// Convert `bytes` (the output of [`to_bytes`]) back into an - /// object. 
This will error if the serialized bytes contain any - /// user defined functions, in which case use - /// [`from_bytes_with_registry`] + /// Convert `bytes` (the output of [`to_bytes`]) back into an object. This + /// will error if the serialized bytes contain any user defined functions, + /// in which case use [`from_bytes_with_ctx`] /// /// [`to_bytes`]: Self::to_bytes - /// [`from_bytes_with_registry`]: Self::from_bytes_with_registry + /// [`from_bytes_with_ctx`]: Self::from_bytes_with_ctx fn from_bytes(bytes: &[u8]) -> Result { - Self::from_bytes_with_registry(bytes, ®istry::NoRegistry {}) + Self::from_bytes_with_ctx(bytes, &TaskContext::default()) } - /// Convert `bytes` (the output of [`to_bytes`]) back into an - /// object resolving user defined functions with the specified - /// `registry` + /// Convert `bytes` (the output of [`to_bytes`]) back into an object, + /// resolving user defined functions with the specified `ctx` /// /// [`to_bytes`]: Self::to_bytes - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result; + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result; } impl Serializeable for Expr { @@ -100,100 +87,22 @@ impl Serializeable for Expr { let bytes: Bytes = buffer.into(); - // the produced byte stream may lead to "recursion limit" errors, see + // The produced byte stream may lead to "recursion limit" errors, see // https://github.com/apache/datafusion/issues/3968 - // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) is fixed, we try to - // deserialize the data here and check for errors. - // - // Need to provide some placeholder registry because the stream may contain UDFs - struct PlaceHolderRegistry; - - impl FunctionRegistry for PlaceHolderRegistry { - fn udfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - - fn udf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udf( - name, - vec![], - arrow::datatypes::DataType::Null, - Volatility::Immutable, - Arc::new(|_| unimplemented!()), - ))) - } - - fn udaf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udaf( - name, - vec![arrow::datatypes::DataType::Null], - Arc::new(arrow::datatypes::DataType::Null), - Volatility::Immutable, - Arc::new(|_| unimplemented!()), - Arc::new(vec![]), - ))) - } - - fn udwf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udwf( - name, - arrow::datatypes::DataType::Null, - Arc::new(arrow::datatypes::DataType::Null), - Volatility::Immutable, - Arc::new(|| unimplemented!()), - ))) - } - fn register_udaf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udaf called in Placeholder Registry!" - ) - } - fn register_udf( - &mut self, - _udf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udf called in Placeholder Registry!" - ) - } - fn register_udwf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udwf called in Placeholder Registry!" - ) - } - - fn expr_planners(&self) -> Vec> { - vec![] - } - - fn udafs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - - fn udwfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - } - Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; + // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) + // is fixed, verify the bytes can be decoded without hitting that limit. 
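+        // Note: this round-trip only validates prost-level decoding; any UDFs
+        // referenced in the bytes are resolved later, when the caller
+        // deserializes with `from_bytes_with_ctx`.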
+ protobuf::LogicalExprNode::decode(bytes.as_ref()) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; Ok(bytes) } - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result { + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result { let protobuf = protobuf::LogicalExprNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; let extension_codec = DefaultLogicalExtensionCodec {}; - logical_plan::from_proto::parse_expr(&protobuf, registry, &extension_codec) + logical_plan::from_proto::parse_expr(&protobuf, ctx, &extension_codec) .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } @@ -327,7 +236,8 @@ pub fn physical_plan_from_json( .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; let proto_converter = DefaultPhysicalProtoConverter {}; - proto_converter.proto_to_execution_plan(ctx, &extension_codec, &back) + let decode_ctx = PhysicalPlanDecodeContext::new(ctx, &extension_codec); + proto_converter.proto_to_execution_plan(&back, &decode_ctx) } /// Deserialize a PhysicalPlan from bytes @@ -369,5 +279,6 @@ pub fn physical_plan_from_bytes_with_proto_converter( ) -> Result> { let protobuf = protobuf::PhysicalPlanNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - proto_converter.proto_to_execution_plan(ctx, extension_codec, &protobuf) + let decode_ctx = PhysicalPlanDecodeContext::new(ctx, extension_codec); + proto_converter.proto_to_execution_plan(&protobuf, &decode_ctx) } diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs deleted file mode 100644 index a3f74787e2b50..0000000000000 --- a/datafusion/proto/src/bytes/registry.rs +++ /dev/null @@ -1,85 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use std::{collections::HashSet, sync::Arc}; - -use datafusion_common::Result; -use datafusion_common::plan_err; -use datafusion_execution::registry::FunctionRegistry; -use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; - -/// A default [`FunctionRegistry`] registry that does not resolve any -/// user defined functions -pub(crate) struct NoRegistry {} - -impl FunctionRegistry for NoRegistry { - fn udfs(&self) -> HashSet { - HashSet::new() - } - - fn udf(&self, name: &str) -> Result> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Function '{name}'" - ) - } - - fn udaf(&self, name: &str) -> Result> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'" - ) - } - - fn udwf(&self, name: &str) -> Result> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Window Function '{name}'" - ) - } - fn register_udaf( - &mut self, - udaf: Arc, - ) -> Result>> { - plan_err!( - "No function registry provided to deserialize, so can not register User Defined Aggregate Function '{}'", - udaf.inner().name() - ) - } - fn register_udf(&mut self, udf: Arc) -> Result>> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Function '{}'", - udf.inner().name() - ) - } - fn register_udwf(&mut self, udwf: Arc) -> Result>> { - plan_err!( - "No function registry provided to deserialize, so can not deserialize User Defined Window Function '{}'", - udwf.inner().name() - ) - } - - fn expr_planners(&self) -> Vec> { - vec![] - } - - fn udafs(&self) -> HashSet { - HashSet::new() - } - - fn udwfs(&self) -> HashSet { - HashSet::new() - } -} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 82bcdac898204..d84fdae345f5b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -12193,6 +12193,9 @@ impl serde::Serialize for LogicalExprNode { logical_expr_node::ExprType::Unnest(v) => { struct_ser.serialize_field("unnest", v)?; } + logical_expr_node::ExprType::ScalarSubqueryExpr(v) => { + struct_ser.serialize_field("scalarSubqueryExpr", v)?; + } } } struct_ser.end() @@ -12254,6 +12257,8 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "similarTo", "placeholder", "unnest", + "scalar_subquery_expr", + "scalarSubqueryExpr", ]; #[allow(clippy::enum_variant_names)] @@ -12289,6 +12294,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { SimilarTo, Placeholder, Unnest, + ScalarSubqueryExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -12341,6 +12347,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { "similarTo" | "similar_to" => Ok(GeneratedField::SimilarTo), "placeholder" => Ok(GeneratedField::Placeholder), "unnest" => Ok(GeneratedField::Unnest), + "scalarSubqueryExpr" | "scalar_subquery_expr" => Ok(GeneratedField::ScalarSubqueryExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -12578,6 +12585,13 @@ impl<'de> serde::Deserialize<'de> for LogicalExprNode { return Err(serde::de::Error::duplicate_field("unnest")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Unnest) +; + } + GeneratedField::ScalarSubqueryExpr => { + if expr_type__.is_some() { + return 
Err(serde::de::Error::duplicate_field("scalarSubqueryExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarSubqueryExpr) ; } } @@ -16610,6 +16624,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::HashExpr(v) => { struct_ser.serialize_field("hashExpr", v)?; } + physical_expr_node::ExprType::ScalarSubquery(v) => { + struct_ser.serialize_field("scalarSubquery", v)?; + } } } struct_ser.end() @@ -16656,6 +16673,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "unknownColumn", "hash_expr", "hashExpr", + "scalar_subquery", + "scalarSubquery", ]; #[allow(clippy::enum_variant_names)] @@ -16680,6 +16699,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { Extension, UnknownColumn, HashExpr, + ScalarSubquery, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16721,6 +16741,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "extension" => Ok(GeneratedField::Extension), "unknownColumn" | "unknown_column" => Ok(GeneratedField::UnknownColumn), "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), + "scalarSubquery" | "scalar_subquery" => Ok(GeneratedField::ScalarSubquery), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16883,6 +16904,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("hashExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::HashExpr) +; + } + GeneratedField::ScalarSubquery => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarSubquery")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarSubquery) ; } } @@ -18121,6 +18149,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::ArrowScan(v) => { struct_ser.serialize_field("arrowScan", v)?; } + physical_plan_node::PhysicalPlanType::ScalarSubquery(v) => { + struct_ser.serialize_field("scalarSubquery", v)?; + } } } struct_ser.end() @@ -18191,6 +18222,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "buffer", "arrow_scan", "arrowScan", + "scalar_subquery", + "scalarSubquery", ]; #[allow(clippy::enum_variant_names)] @@ -18232,6 +18265,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { AsyncFunc, Buffer, ArrowScan, + ScalarSubquery, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18290,6 +18324,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "asyncFunc" | "async_func" => Ok(GeneratedField::AsyncFunc), "buffer" => Ok(GeneratedField::Buffer), "arrowScan" | "arrow_scan" => Ok(GeneratedField::ArrowScan), + "scalarSubquery" | "scalar_subquery" => Ok(GeneratedField::ScalarSubquery), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18569,6 +18604,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("arrowScan")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ArrowScan) +; + } + GeneratedField::ScalarSubquery => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarSubquery")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ScalarSubquery) ; } } @@ -18581,6 
+18623,134 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalScalarSubqueryExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.data_type.is_some() { + len += 1; + } + if self.nullable { + len += 1; + } + if self.index != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarSubqueryExprNode", len)?; + if let Some(v) = self.data_type.as_ref() { + struct_ser.serialize_field("dataType", v)?; + } + if self.nullable { + struct_ser.serialize_field("nullable", &self.nullable)?; + } + if self.index != 0 { + struct_ser.serialize_field("index", &self.index)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalScalarSubqueryExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "data_type", + "dataType", + "nullable", + "index", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + DataType, + Nullable, + Index, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "dataType" | "data_type" => Ok(GeneratedField::DataType), + "nullable" => Ok(GeneratedField::Nullable), + "index" => Ok(GeneratedField::Index), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalScalarSubqueryExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalScalarSubqueryExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut data_type__ = None; + let mut nullable__ = None; + let mut index__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::DataType => { + if data_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dataType")); + } + data_type__ = map_.next_value()?; + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = Some(map_.next_value()?); + } + GeneratedField::Index => { + if index__.is_some() { + return Err(serde::de::Error::duplicate_field("index")); + } + index__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(PhysicalScalarSubqueryExprNode { + data_type: data_type__, + nullable: nullable__.unwrap_or_default(), + index: index__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalScalarSubqueryExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalScalarUdfNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -21250,6 +21420,205 @@ impl<'de> serde::Deserialize<'de> for RollupNode { deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ScalarSubqueryExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if !self.subqueries.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarSubqueryExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.subqueries.is_empty() { + struct_ser.serialize_field("subqueries", &self.subqueries)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarSubqueryExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "subqueries", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Subqueries, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "subqueries" => Ok(GeneratedField::Subqueries), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarSubqueryExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ScalarSubqueryExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut subqueries__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Subqueries => { + if subqueries__.is_some() { + return Err(serde::de::Error::duplicate_field("subqueries")); + } + subqueries__ = Some(map_.next_value()?); + } + } + } + Ok(ScalarSubqueryExecNode { + input: input__, + subqueries: subqueries__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.ScalarSubqueryExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ScalarSubqueryExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.subquery.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarSubqueryExprNode", len)?; + if let Some(v) = self.subquery.as_ref() { + struct_ser.serialize_field("subquery", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarSubqueryExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "subquery", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Subquery, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "subquery" => Ok(GeneratedField::Subquery), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarSubqueryExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ScalarSubqueryExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut subquery__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Subquery => { + if subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("subquery")); + } + subquery__ = map_.next_value()?; + } + } + } + Ok(ScalarSubqueryExprNode { + subquery: subquery__, + }) + } + } + deserializer.deserialize_struct("datafusion.ScalarSubqueryExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ScalarUdfExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -22927,6 +23296,115 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for SubqueryNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.subquery.is_some() { + len += 1; + } + if !self.outer_ref_columns.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SubqueryNode", len)?; + if let Some(v) = self.subquery.as_ref() { + struct_ser.serialize_field("subquery", v)?; + } + if !self.outer_ref_columns.is_empty() { + struct_ser.serialize_field("outerRefColumns", &self.outer_ref_columns)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for SubqueryNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "subquery", + "outer_ref_columns", + "outerRefColumns", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Subquery, + OuterRefColumns, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "subquery" => Ok(GeneratedField::Subquery), + "outerRefColumns" | "outer_ref_columns" => Ok(GeneratedField::OuterRefColumns), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SubqueryNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SubqueryNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut subquery__ = None; + let mut outer_ref_columns__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Subquery => { + if subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("subquery")); + } + subquery__ = map_.next_value()?; + } + GeneratedField::OuterRefColumns => { + if outer_ref_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("outerRefColumns")); + } + outer_ref_columns__ = Some(map_.next_value()?); + } + } + } + Ok(SubqueryNode { + subquery: subquery__, + outer_ref_columns: outer_ref_columns__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SubqueryNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for SymmetricHashJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ab60c3058dbde..1dc259926a87a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -195,8 +195,8 @@ pub mod projection_node { pub struct SelectionNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub expr: ::core::option::Option, + #[prost(message, optional, boxed, tag = "2")] + pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortNode { @@ -382,8 +382,8 @@ pub struct JoinNode { pub right_join_key: ::prost::alloc::vec::Vec, #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] pub null_equality: i32, - #[prost(message, optional, tag = "8")] - pub filter: ::core::option::Option, + #[prost(message, optional, boxed, tag = "8")] + pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistinctNode { @@ -578,7 +578,7 @@ pub struct SubqueryAliasNode { pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" )] pub expr_type: ::core::option::Option, } @@ -656,6 +656,9 @@ pub mod logical_expr_node { Placeholder(super::PlaceholderNode), #[prost(message, tag = "35")] Unnest(super::Unnest), + /// Subquery expressions + #[prost(message, tag = "36")] + ScalarSubqueryExpr(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -664,6 +667,18 @@ pub struct Wildcard { pub qualifier: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct SubqueryNode { + #[prost(message, optional, boxed, tag = "1")] + pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub outer_ref_columns: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarSubqueryExprNode { + #[prost(message, optional, boxed, tag = "1")] + pub subquery: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderNode { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, @@ -1102,7 +1117,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 
29, 30, 31, 32, 33, 34, 35, 36, 37, 38" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39" )] pub physical_plan_type: ::core::option::Option, } @@ -1186,6 +1201,8 @@ pub mod physical_plan_node { Buffer(::prost::alloc::boxed::Box), #[prost(message, tag = "38")] ArrowScan(super::ArrowScanExecNode), + #[prost(message, tag = "39")] + ScalarSubquery(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1317,7 +1334,7 @@ pub struct PhysicalExprNode { pub expr_id: ::core::option::Option, #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21, 22" )] pub expr_type: ::core::option::Option, } @@ -1370,6 +1387,8 @@ pub mod physical_expr_node { UnknownColumn(super::UnknownColumn), #[prost(message, tag = "21")] HashExpr(super::PhysicalHashExprNode), + #[prost(message, tag = "22")] + ScalarSubquery(super::PhysicalScalarSubqueryExprNode), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -2202,6 +2221,22 @@ pub struct BufferExecNode { #[prost(uint64, tag = "2")] pub capacity: u64, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarSubqueryExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub subqueries: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalScalarSubqueryExprNode { + #[prost(message, optional, tag = "1")] + pub data_type: ::core::option::Option, + #[prost(bool, tag = "2")] + pub nullable: bool, + #[prost(uint32, tag = "3")] + pub index: u32, +} /// Identifies a built-in file format supported by DataFusion. /// Used by DefaultLogicalExtensionCodec to serialize/deserialize /// FileFormatFactory instances (e.g. in CopyTo plans). 
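To make the new wire format concrete: on the plan side, `ScalarSubqueryExecNode` pairs the pass-through input with its subquery plans, while on the expression side each subquery reference shrinks to three fields. A hedged sketch using the generated prost structs above (the `int32_arrow_type()` helper is hypothetical, since `ArrowType` construction lives in `datafusion_common` and is not part of this diff):

```rust
// Sketch only. The exec node's `subqueries` list carries no per-entry index,
// so the expression's `index` presumably refers to a subquery's position in
// that list and in the shared results container.
let expr_node = PhysicalScalarSubqueryExprNode {
    data_type: Some(int32_arrow_type()), // hypothetical helper for an Int32 ArrowType
    nullable: true,
    index: 0,
};
```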
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ed33d9fab1820..78ffd362c8e48 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -23,10 +23,12 @@ use datafusion_common::{ NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, exec_datafusion_err, internal_err, plan_datafusion_err, }; +use datafusion_execution::TaskContext; use datafusion_execution::registry::FunctionRegistry; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{Alias, NullTreatment, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, GroupingSet::GroupingSets, @@ -52,7 +54,7 @@ use crate::protobuf::{ }, }; -use super::LogicalExtensionCodec; +use super::{AsLogicalPlan, LogicalExtensionCodec}; impl From<&protobuf::UnnestOptions> for UnnestOptions { fn from(opts: &protobuf::UnnestOptions) -> Self { @@ -256,7 +258,7 @@ impl From for NullTreatment { pub fn parse_expr( proto: &protobuf::LogicalExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { use protobuf::{logical_expr_node::ExprType, window_expr_node}; @@ -269,7 +271,7 @@ pub fn parse_expr( match expr_type { ExprType::BinaryExpr(binary_expr) => { let op = from_proto_binary_op(&binary_expr.op)?; - let operands = parse_exprs(&binary_expr.operands, registry, codec)?; + let operands = parse_exprs(&binary_expr.operands, ctx, codec)?; if operands.len() < 2 { return Err(proto_error( @@ -296,8 +298,8 @@ pub fn parse_expr( .window_function .as_ref() .ok_or_else(|| Error::required("window_function"))?; - let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; - let mut order_by = parse_sorts(&expr.order_by, registry, codec)?; + let partition_by = parse_exprs(&expr.partition_by, ctx, codec)?; + let mut order_by = parse_sorts(&expr.order_by, ctx, codec)?; let window_frame = expr .window_frame .as_ref() @@ -329,7 +331,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry + None => ctx .udaf(udaf_name) .or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }; @@ -338,7 +340,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry + None => ctx .udwf(udwf_name) .or_else(|_| codec.try_decode_udwf(udwf_name, &[]))?, }; @@ -346,7 +348,7 @@ pub fn parse_expr( } }; - let args = parse_exprs(&expr.exprs, registry, codec)?; + let args = parse_exprs(&expr.exprs, ctx, codec)?; let mut builder = Expr::from(WindowFunction::new(agg_fn, args)) .partition_by(partition_by) .order_by(order_by) @@ -357,8 +359,7 @@ pub fn parse_expr( builder = builder.distinct(); }; - if let Some(filter) = - parse_optional_expr(expr.filter.as_deref(), registry, codec)? + if let Some(filter) = parse_optional_expr(expr.filter.as_deref(), ctx, codec)? 
{ builder = builder.filter(filter); } @@ -366,7 +367,7 @@ pub fn parse_expr( builder.build().map_err(Error::DataFusionError) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(alias.expr.as_deref(), ctx, "expr", codec)?, alias .relation .first() @@ -376,69 +377,69 @@ pub fn parse_expr( ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( is_null.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( - parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(is_not_null.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( not.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( - parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(msg.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Between(between) => Ok(Expr::Between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), - registry, + ctx, "expr", codec, )?), between.negated, Box::new(parse_required_expr( between.low.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( between.high.as_deref(), - registry, + ctx, "expr", codec, )?), @@ -447,13 +448,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -464,13 +465,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -481,13 +482,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -501,13 +502,13 @@ pub fn parse_expr( .map(|e| { let when_expr = parse_required_expr( e.when_expr.as_ref(), - registry, + ctx, "when_expr", codec, )?; let then_expr = parse_required_expr( e.then_expr.as_ref(), - registry, + ctx, "then_expr", codec, )?; @@ -515,16 +516,15 @@ pub fn parse_expr( }) .collect::, Box)>, Error>>()?; Ok(Expr::Case(Case::new( - parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), + parse_optional_expr(case.expr.as_deref(), ctx, codec)?.map(Box::new), when_then_expr, - parse_optional_expr(case.else_expr.as_deref(), registry, codec)? 
- .map(Box::new), + parse_optional_expr(case.else_expr.as_deref(), ctx, codec)?.map(Box::new), ))) } ExprType::Cast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); @@ -537,7 +537,7 @@ pub fn parse_expr( ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); @@ -551,10 +551,10 @@ pub fn parse_expr( ))) } ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( - parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(negative.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Unnest(unnest) => { - let mut exprs = parse_exprs(&unnest.exprs, registry, codec)?; + let mut exprs = parse_exprs(&unnest.exprs, ctx, codec)?; if exprs.len() != 1 { return Err(proto_error("Unnest must have exactly one expression")); } @@ -563,11 +563,11 @@ pub fn parse_expr( ExprType::InList(in_list) => Ok(Expr::InList(InList::new( Box::new(parse_required_expr( in_list.expr.as_deref(), - registry, + ctx, "expr", codec, )?), - parse_exprs(&in_list.list, registry, codec)?, + parse_exprs(&in_list.list, ctx, codec)?, in_list.negated, ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => { @@ -585,19 +585,19 @@ pub fn parse_expr( }) => { let scalar_fn = match fun_definition { Some(buf) => codec.try_decode_udf(fun_name, buf)?, - None => registry + None => ctx .udf(fun_name.as_str()) .or_else(|_| codec.try_decode_udf(fun_name, &[]))?, }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, - parse_exprs(args, registry, codec)?, + parse_exprs(args, ctx, codec)?, ))) } ExprType::AggregateUdfExpr(pb) => { let agg_fn = match &pb.fun_definition { Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, - None => registry + None => ctx .udaf(&pb.fun_name) .or_else(|_| codec.try_decode_udaf(&pb.fun_name, &[]))?, }; @@ -616,10 +616,10 @@ pub fn parse_expr( Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, - parse_exprs(&pb.args, registry, codec)?, + parse_exprs(&pb.args, ctx, codec)?, pb.distinct, - parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), - parse_sorts(&pb.order_by, registry, codec)?, + parse_optional_expr(pb.filter.as_deref(), ctx, codec)?.map(Box::new), + parse_sorts(&pb.order_by, ctx, codec)?, null_treatment, ))) } @@ -627,15 +627,15 @@ pub fn parse_expr( ExprType::GroupingSet(GroupingSetNode { expr }) => { Ok(Expr::GroupingSet(GroupingSets( expr.iter() - .map(|expr_list| parse_exprs(&expr_list.expr, registry, codec)) + .map(|expr_list| parse_exprs(&expr_list.expr, ctx, codec)) .collect::, Error>>()?, ))) } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( - parse_exprs(expr, registry, codec)?, + parse_exprs(expr, ctx, codec)?, ))), ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet( - GroupingSet::Rollup(parse_exprs(expr, registry, codec)?), + GroupingSet::Rollup(parse_exprs(expr, ctx, codec)?), )), ExprType::Placeholder(PlaceholderNode { id, @@ -657,13 +657,41 @@ pub fn parse_expr( ))) } }, + ExprType::ScalarSubqueryExpr(sq) => { + let subquery = parse_subquery( + sq.subquery + .as_deref() + .ok_or_else(|| Error::required("ScalarSubqueryExprNode.subquery"))?, + ctx, + codec, + )?; + Ok(Expr::ScalarSubquery(subquery)) + } } } +fn parse_subquery( + proto: &protobuf::SubqueryNode, + ctx: &TaskContext, + codec: &dyn LogicalExtensionCodec, +) -> Result { + let plan_node = proto + .subquery + .as_ref() + .ok_or_else(|| 
Error::required("SubqueryNode.subquery"))?; + let plan = plan_node.try_into_logical_plan(ctx, codec)?; + let outer_ref_columns = parse_exprs(&proto.outer_ref_columns, ctx, codec)?; + Ok(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Default::default(), + }) +} + /// Parse a vector of `protobuf::LogicalExprNode`s. pub fn parse_exprs<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -672,7 +700,7 @@ where let res = protos .into_iter() .map(|elem| { - parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + parse_expr(elem, ctx, codec).map_err(|e| plan_datafusion_err!("{}", e)) }) .collect::>>()?; Ok(res) @@ -680,7 +708,7 @@ where pub fn parse_sorts<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -688,17 +716,17 @@ where { protos .into_iter() - .map(|sort| parse_sort(sort, registry, codec)) + .map(|sort| parse_sort(sort, ctx, codec)) .collect::, Error>>() } pub fn parse_sort( sort: &protobuf::SortExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { Ok(Sort::new( - parse_required_expr(sort.expr.as_ref(), registry, "expr", codec)?, + parse_required_expr(sort.expr.as_ref(), ctx, "expr", codec)?, sort.asc, sort.nulls_first, )) @@ -754,23 +782,23 @@ pub fn from_proto_binary_op(op: &str) -> Result { fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> { match p { - Some(expr) => parse_expr(expr, registry, codec).map(Some), + Some(expr) => parse_expr(expr, ctx, codec).map(Some), None => Ok(None), } } fn parse_required_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, field: impl Into, codec: &dyn LogicalExtensionCodec, ) -> Result { match p { - Some(expr) => parse_expr(expr, registry, codec), + Some(expr) => parse_expr(expr, ctx, codec), None => Err(Error::required(field)), } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9715ecf8d97ce..eea2eb3364b2c 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1321,10 +1321,10 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), - expr: Some(serialize_expr( + expr: Some(Box::new(serialize_expr( &filter.predicate, extension_codec, - )?), + )?)), }, ))), }) @@ -1440,7 +1440,7 @@ impl AsLogicalPlan for LogicalPlanNode { null_equality.to_owned().into(); let filter = filter .as_ref() - .map(|e| serialize_expr(e, extension_codec)) + .map(|e| serialize_expr(e, extension_codec).map(Box::new)) .map_or(Ok(None), |v| v.map(Some))?; Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( @@ -1457,8 +1457,14 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Subquery(_) => { - not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") + LogicalPlan::Subquery(subquery) => { + // Serialize the inner subquery plan directly — the + // LogicalPlan::Subquery wrapper is reconstructed during + // expression deserialization. 
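+                // The outer_ref_columns are not lost on this path: they are
+                // carried by the expression-level SubqueryNode (see
+                // serialize_subquery in to_proto.rs below).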
+ LogicalPlanNode::try_from_logical_plan( + &subquery.subquery, + extension_codec, + ) } LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6fcb7389922ad..bd5c4b585c24f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -28,6 +28,7 @@ use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, NullTreatment, Placeholder, ScalarFunction, Unnest, }; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ Expr, JoinConstraint, JoinType, SortExpr, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, logical_plan::PlanType, @@ -48,7 +49,8 @@ use crate::protobuf::{ }, }; -use super::LogicalExtensionCodec; +use super::{AsLogicalPlan, LogicalExtensionCodec}; +use crate::protobuf::LogicalPlanNode; impl From<&UnnestOptions> for protobuf::UnnestOptions { fn from(opts: &UnnestOptions) -> Self { @@ -579,14 +581,20 @@ pub fn serialize_expr( qualifier: qualifier.to_owned().map(|x| x.into()), })), }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) - | Expr::Exists { .. } - | Expr::SetComparison(_) - | Expr::OuterReferenceColumn { .. } => { - // we would need to add logical plan operators to datafusion.proto to support this - // see discussion in https://github.com/apache/datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); + Expr::ScalarSubquery(subquery) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarSubqueryExpr(Box::new( + protobuf::ScalarSubqueryExprNode { + subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), + }, + ))), + }, + Expr::InSubquery(_) + | Expr::Exists(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::SetComparison(_) => { + return Err(Error::General(format!( + "Proto serialization error: {expr} is not yet supported" + ))); } Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { @@ -631,6 +639,19 @@ pub fn serialize_expr( Ok(expr_node) } +fn serialize_subquery( + subquery: &Subquery, + codec: &dyn LogicalExtensionCodec, +) -> Result { + let plan = LogicalPlanNode::try_from_logical_plan(&subquery.subquery, codec) + .map_err(|e| Error::General(e.to_string()))?; + let outer_ref_columns = serialize_exprs(&subquery.outer_ref_columns, codec)?; + Ok(protobuf::SubqueryNode { + subquery: Some(Box::new(plan)), + outer_ref_columns, + }) +} + pub fn serialize_sorts<'a, I>( sorts: I, codec: &dyn LogicalExtensionCodec, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 61665db607757..732676941bebb 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -38,7 +38,9 @@ use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; use datafusion_expr::WindowFunctionDefinition; use datafusion_expr::dml::InsertOp; +use datafusion_expr::execution_props::SubqueryIndex; use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs}; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use 
datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion_physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, @@ -52,7 +54,7 @@ use object_store::ObjectMeta; use object_store::path::Path; use super::{ - DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, PhysicalPlanDecodeContext, PhysicalProtoConverterExtension, }; use crate::logical_plan::{self}; @@ -70,24 +72,21 @@ impl From<&protobuf::PhysicalColumn> for Column { /// # Arguments /// /// * `proto` - Input proto with physical sort expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. -/// * `codec` - An extension codec used to decode custom UDFs. +/// * `ctx` - Decode context carrying the task context, extension codec, and +/// any scoped state needed during recursive deserialization. +/// * `proto_converter` - Converter hooks used for recursive physical plan and +/// expression deserialization. pub fn parse_physical_sort_expr( proto: &protobuf::PhysicalSortExprNode, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { if let Some(expr) = &proto.expr { - let expr = proto_converter.proto_to_physical_expr( - expr.as_ref(), - ctx, - input_schema, - codec, - )?; + let expr = + proto_converter.proto_to_physical_expr(expr.as_ref(), input_schema, ctx)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -103,21 +102,22 @@ pub fn parse_physical_sort_expr( /// # Arguments /// /// * `proto` - Input proto with vector of physical sort expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. -/// * `codec` - An extension codec used to decode custom UDFs. +/// * `ctx` - Decode context carrying the task context, extension codec, and +/// any scoped state needed during recursive deserialization. +/// * `proto_converter` - Converter hooks used for recursive physical plan and +/// expression deserialization. pub fn parse_physical_sort_exprs( proto: &[protobuf::PhysicalSortExprNode], - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { proto .iter() .map(|sort_expr| { - parse_physical_sort_expr(sort_expr, ctx, input_schema, codec, proto_converter) + parse_physical_sort_expr(sort_expr, ctx, input_schema, proto_converter) }) .collect() } @@ -128,34 +128,25 @@ pub fn parse_physical_sort_exprs( /// /// * `proto` - Input proto with physical window expression node. /// * `name` - Name of the window expression. -/// * `registry` - A registry knows how to build logical expressions out of user-defined function names -/// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. -/// * `codec` - An extension codec used to decode custom UDFs. 
+/// * `input_schema` - The Arrow schema for the input, used for determining +/// expression data types when performing type coercion. +/// * `ctx` - Decode context carrying the task context, extension codec, and +/// any scoped state needed during recursive deserialization. +/// * `proto_converter` - Converter hooks used for recursive physical plan and +/// expression deserialization. pub fn parse_physical_window_expr( proto: &protobuf::PhysicalWindowExprNode, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let window_node_expr = - parse_physical_exprs(&proto.args, ctx, input_schema, codec, proto_converter)?; - let partition_by = parse_physical_exprs( - &proto.partition_by, - ctx, - input_schema, - codec, - proto_converter, - )?; + parse_physical_exprs(&proto.args, ctx, input_schema, proto_converter)?; + let partition_by = + parse_physical_exprs(&proto.partition_by, ctx, input_schema, proto_converter)?; - let order_by = parse_physical_sort_exprs( - &proto.order_by, - ctx, - input_schema, - codec, - proto_converter, - )?; + let order_by = + parse_physical_sort_exprs(&proto.order_by, ctx, input_schema, proto_converter)?; let window_frame = proto .window_frame @@ -171,14 +162,20 @@ pub fn parse_physical_window_expr( match window_func { protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name) => { WindowFunctionDefinition::AggregateUDF(match &proto.fun_definition { - Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => ctx.udaf(udaf_name).or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, + Some(buf) => ctx.codec().try_decode_udaf(udaf_name, buf)?, + None => ctx + .task_ctx() + .udaf(udaf_name) + .or_else(|_| ctx.codec().try_decode_udaf(udaf_name, &[]))?, }) } protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => { WindowFunctionDefinition::WindowUDF(match &proto.fun_definition { - Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => ctx.udwf(udwf_name).or_else(|_| codec.try_decode_udwf(udwf_name, &[]))? + Some(buf) => ctx.codec().try_decode_udwf(udwf_name, buf)?, + None => ctx + .task_ctx() + .udwf(udwf_name) + .or_else(|_| ctx.codec().try_decode_udwf(udwf_name, &[]))? }) } } @@ -206,9 +203,8 @@ pub fn parse_physical_window_expr( pub fn parse_physical_exprs<'a, I>( protos: I, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result>> where @@ -216,7 +212,7 @@ where { protos .into_iter() - .map(|p| proto_converter.proto_to_physical_expr(p, ctx, input_schema, codec)) + .map(|p| proto_converter.proto_to_physical_expr(p, input_schema, ctx)) .collect::>>() } @@ -225,21 +221,22 @@ where /// # Arguments /// /// * `proto` - Input proto with physical expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function names -/// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. -/// * `codec` - An extension codec used to decode custom UDFs. +/// * `ctx` - Task context used to resolve registered functions. +/// * `input_schema` - The Arrow schema for the input, used for determining +/// expression data types when performing type coercion. 
+/// * `codec` - Physical extension codec used to construct the root decode +/// context for deserialization. pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { + let decode_ctx = PhysicalPlanDecodeContext::new(ctx, codec); parse_physical_expr_with_converter( proto, - ctx, input_schema, - codec, + &decode_ctx, &DefaultPhysicalProtoConverter {}, ) } @@ -249,16 +246,16 @@ pub fn parse_physical_expr( /// # Arguments /// /// * `proto` - Input proto with physical expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function names -/// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. -/// * `codec` - An extension codec used to decode custom UDFs. -/// * `proto_converter` - Conversion functions for physical plans and expressions +/// * `input_schema` - The Arrow schema for the input, used for determining +/// expression data types when performing type coercion. +/// * `ctx` - Decode context carrying the task context, extension codec, and +/// any scoped state needed during recursive deserialization. +/// * `proto_converter` - Converter hooks used for recursive physical plan and +/// expression deserialization. pub fn parse_physical_expr_with_converter( proto: &protobuf::PhysicalExprNode, - ctx: &TaskContext, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let expr_type = proto @@ -281,14 +278,7 @@ pub fn parse_physical_expr_with_converter( let operands: Vec> = binary_expr .operands .iter() - .map(|e| { - proto_converter.proto_to_physical_expr( - e, - ctx, - input_schema, - codec, - ) - }) + .map(|e| proto_converter.proto_to_physical_expr(e, input_schema, ctx)) .collect::>>()?; if operands.len() < 2 { @@ -311,7 +301,6 @@ pub fn parse_physical_expr_with_converter( ctx, "left", input_schema, - codec, proto_converter, )?, op, @@ -320,7 +309,6 @@ pub fn parse_physical_expr_with_converter( ctx, "right", input_schema, - codec, proto_converter, )?, )) @@ -345,7 +333,6 @@ pub fn parse_physical_expr_with_converter( ctx, "expr", input_schema, - codec, proto_converter, )?)) } @@ -355,7 +342,6 @@ pub fn parse_physical_expr_with_converter( ctx, "expr", input_schema, - codec, proto_converter, )?)) } @@ -364,7 +350,6 @@ pub fn parse_physical_expr_with_converter( ctx, "expr", input_schema, - codec, proto_converter, )?)), ExprType::Negative(e) => { @@ -373,7 +358,6 @@ pub fn parse_physical_expr_with_converter( ctx, "expr", input_schema, - codec, proto_converter, )?)) } @@ -383,10 +367,9 @@ pub fn parse_physical_expr_with_converter( ctx, "expr", input_schema, - codec, proto_converter, )?, - parse_physical_exprs(&e.list, ctx, input_schema, codec, proto_converter)?, + parse_physical_exprs(&e.list, ctx, input_schema, proto_converter)?, &e.negated, input_schema, )?, @@ -394,12 +377,7 @@ pub fn parse_physical_expr_with_converter( e.expr .as_ref() .map(|e| { - proto_converter.proto_to_physical_expr( - e.as_ref(), - ctx, - input_schema, - codec, - ) + proto_converter.proto_to_physical_expr(e.as_ref(), input_schema, ctx) }) .transpose()?, e.when_then_expr @@ -411,7 +389,6 @@ pub fn parse_physical_expr_with_converter( ctx, "when_expr", input_schema, - codec, proto_converter, )?, parse_required_physical_expr( @@ -419,7 +396,6 @@ pub fn 
parse_physical_expr_with_converter( ctx, "then_expr", input_schema, - codec, proto_converter, )?, )) @@ -428,12 +404,7 @@ pub fn parse_physical_expr_with_converter( e.else_expr .as_ref() .map(|e| { - proto_converter.proto_to_physical_expr( - e.as_ref(), - ctx, - input_schema, - codec, - ) + proto_converter.proto_to_physical_expr(e.as_ref(), input_schema, ctx) }) .transpose()?, )?), @@ -443,7 +414,6 @@ pub fn parse_physical_expr_with_converter( ctx, "expr", input_schema, - codec, proto_converter, )?, convert_required!(e.arrow_type)?, @@ -455,24 +425,23 @@ pub fn parse_physical_expr_with_converter( ctx, "expr", input_schema, - codec, proto_converter, )?, convert_required!(e.arrow_type)?, )), ExprType::ScalarUdf(e) => { let udf = match &e.fun_definition { - Some(buf) => codec.try_decode_udf(&e.name, buf)?, + Some(buf) => ctx.codec().try_decode_udf(&e.name, buf)?, None => ctx + .task_ctx() .udf(e.name.as_str()) - .or_else(|_| codec.try_decode_udf(&e.name, &[]))?, + .or_else(|_| ctx.codec().try_decode_udf(&e.name, &[]))?, }; let scalar_fun_def = Arc::clone(&udf); - let args = - parse_physical_exprs(&e.args, ctx, input_schema, codec, proto_converter)?; + let args = parse_physical_exprs(&e.args, ctx, input_schema, proto_converter)?; - let config_options = Arc::clone(ctx.session_config().options()); + let config_options = Arc::clone(ctx.task_ctx().session_config().options()); Arc::new( ScalarFunctionExpr::new( @@ -498,7 +467,6 @@ pub fn parse_physical_expr_with_converter( ctx, "expr", input_schema, - codec, proto_converter, )?, parse_required_physical_expr( @@ -506,7 +474,6 @@ pub fn parse_physical_expr_with_converter( ctx, "pattern", input_schema, - codec, proto_converter, )?, )), @@ -515,7 +482,6 @@ pub fn parse_physical_expr_with_converter( &hash_expr.on_columns, ctx, input_schema, - codec, proto_converter, )?; Arc::new(HashExpr::new( @@ -524,15 +490,35 @@ pub fn parse_physical_expr_with_converter( hash_expr.description.clone(), )) } + ExprType::ScalarSubquery(sq) => { + let data_type: arrow::datatypes::DataType = sq + .data_type + .as_ref() + .ok_or_else(|| { + proto_error("Missing data_type in PhysicalScalarSubqueryExprNode") + })? + .try_into()?; + let results = ctx.scalar_subquery_results().ok_or_else(|| { + proto_error( + "ScalarSubqueryExpr can only be deserialized as part \ + of a surrounding ScalarSubqueryExec", + ) + })?; + Arc::new(ScalarSubqueryExpr::new( + data_type, + sq.nullable, + SubqueryIndex::new(sq.index as usize), + results.clone(), + )) + } ExprType::Extension(extension) => { let inputs: Vec> = extension .inputs .iter() - .map(|e| { - proto_converter.proto_to_physical_expr(e, ctx, input_schema, codec) - }) + .map(|e| proto_converter.proto_to_physical_expr(e, input_schema, ctx)) .collect::>()?; - codec.try_decode_expr(extension.expr.as_slice(), &inputs)? as _ + ctx.codec() + .try_decode_expr(extension.expr.as_slice(), &inputs)? as _ } }; @@ -541,22 +527,20 @@ pub fn parse_physical_expr_with_converter( fn parse_required_physical_expr( expr: Option<&protobuf::PhysicalExprNode>, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, field: &str, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - expr.map(|e| proto_converter.proto_to_physical_expr(e, ctx, input_schema, codec)) + expr.map(|e| proto_converter.proto_to_physical_expr(e, input_schema, ctx)) .transpose()? 
.ok_or_else(|| internal_datafusion_err!("Missing required field {field:?}")) } pub fn parse_protobuf_hash_partitioning( partitioning: Option<&protobuf::PhysicalHashRepartition>, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { @@ -565,7 +549,6 @@ pub fn parse_protobuf_hash_partitioning( &hash_part.hash_expr, ctx, input_schema, - codec, proto_converter, )?; @@ -580,9 +563,8 @@ pub fn parse_protobuf_hash_partitioning( pub fn parse_protobuf_partitioning( partitioning: Option<&protobuf::Partitioning>, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { @@ -597,7 +579,6 @@ pub fn parse_protobuf_partitioning( Some(hash_repartition), ctx, input_schema, - codec, proto_converter, ) } @@ -651,8 +632,7 @@ pub fn parse_table_schema_from_proto( pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, file_source: Arc, ) -> Result { @@ -678,7 +658,6 @@ pub fn parse_protobuf_file_scan_config( &node_collection.physical_sort_expr_nodes, ctx, &schema, - codec, proto_converter, )?; output_ordering.extend(LexOrdering::new(sort_exprs)); @@ -694,9 +673,8 @@ pub fn parse_protobuf_file_scan_config( proto_expr.expr.as_ref().ok_or_else(|| { internal_datafusion_err!("ProjectionExpr missing expr field") })?, - ctx, &schema, - codec, + ctx, )?; Ok(ProjectionExpr::new(expr, proto_expr.alias.clone())) }) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 206f4378d3d3b..0d169ab82c438 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -51,6 +51,7 @@ use datafusion_datasource_parquet::source::ParquetSource; #[cfg(feature = "parquet")] use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; +use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion_functions_table::generate_series::{ Empty, GenSeriesArgs, GenerateSeriesTable, GenericSeriesState, TimestampValue, @@ -83,6 +84,7 @@ use datafusion_physical_plan::metrics::{MetricCategory, MetricType}; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::scalar_subquery::{ScalarSubqueryExec, ScalarSubqueryLink}; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; @@ -117,6 +119,58 @@ use crate::{convert_required, into_required}; pub mod from_proto; pub mod to_proto; +/// Context threaded through physical-plan deserialization. +/// +/// This bundles the stable per-call inputs for deserialization and the +/// per-scope `ScalarSubqueryResults` handle needed while reconstructing +/// `ScalarSubqueryExpr` nodes inside a `ScalarSubqueryExec` input plan. 
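+///
+/// Illustrative sketch of the intended flow (not additional API, just the
+/// methods defined below): a root context is created once per decode call,
+/// and a scoped child is derived while decoding the input of a
+/// `ScalarSubqueryExec`:
+///
+/// ```ignore
+/// let root = PhysicalPlanDecodeContext::new(task_ctx, codec);
+/// assert!(root.scalar_subquery_results().is_none());
+/// let scoped = root.with_scalar_subquery_results(results);
+/// assert!(scoped.scalar_subquery_results().is_some());
+/// ```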
+#[derive(Clone)] +pub struct PhysicalPlanDecodeContext<'a> { + task_ctx: &'a TaskContext, + codec: &'a dyn PhysicalExtensionCodec, + scalar_subquery_results: Option<ScalarSubqueryResults>, +} + +impl<'a> PhysicalPlanDecodeContext<'a> { + /// Creates a new root decode context. + pub fn new(task_ctx: &'a TaskContext, codec: &'a dyn PhysicalExtensionCodec) -> Self { + Self { + task_ctx, + codec, + scalar_subquery_results: None, + } + } + + /// Returns the task context used for deserialization. + pub fn task_ctx(&self) -> &'a TaskContext { + self.task_ctx + } + + /// Returns the physical extension codec used for deserialization. + pub fn codec(&self) -> &'a dyn PhysicalExtensionCodec { + self.codec + } + + /// Returns the scalar subquery results container for the current scope, if + /// one is active. + pub fn scalar_subquery_results(&self) -> Option<&ScalarSubqueryResults> { + self.scalar_subquery_results.as_ref() + } + + /// Returns a child context with a different scalar subquery results + /// container. + pub fn with_scalar_subquery_results( + &self, + scalar_subquery_results: ScalarSubqueryResults, + ) -> Self { + Self { + task_ctx: self.task_ctx, + codec: self.codec, + scalar_subquery_results: Some(scalar_subquery_results), + } + } +} + impl AsExecutionPlan for protobuf::PhysicalPlanNode { fn try_decode(buf: &[u8]) -> Result<Self> where @@ -168,9 +222,17 @@ impl protobuf::PhysicalPlanNode { pub fn try_into_physical_plan_with_converter( &self, ctx: &TaskContext, codec: &dyn PhysicalExtensionCodec, proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<Arc<dyn ExecutionPlan>> { + let decode_ctx = PhysicalPlanDecodeContext::new(ctx, codec); + self.try_into_physical_plan_with_context(&decode_ctx, proto_converter) + } + + pub(crate) fn try_into_physical_plan_with_context( + &self, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<Arc<dyn ExecutionPlan>> { let plan = self.physical_plan_type.as_ref().ok_or_else(|| { proto_error(format!( "physical_plan::from_proto() Unsupported physical plan '{self:?}'" )) })?; match plan { PhysicalPlanType::Explain(explain) => { - self.try_into_explain_physical_plan(explain, ctx, codec, proto_converter) + self.try_into_explain_physical_plan(explain, ctx, proto_converter) + } + PhysicalPlanType::Projection(projection) => { + self.try_into_projection_physical_plan(projection, ctx, proto_converter) } - PhysicalPlanType::Projection(projection) => self - .try_into_projection_physical_plan( - projection, - ctx, - codec, - proto_converter, - ), PhysicalPlanType::Filter(filter) => { - self.try_into_filter_physical_plan(filter, ctx, codec, proto_converter) + self.try_into_filter_physical_plan(filter, ctx, proto_converter) } PhysicalPlanType::CsvScan(scan) => { - self.try_into_csv_scan_physical_plan(scan, ctx, codec, proto_converter) + self.try_into_csv_scan_physical_plan(scan, ctx, proto_converter) } PhysicalPlanType::JsonScan(scan) => { - self.try_into_json_scan_physical_plan(scan, ctx, codec, proto_converter) + self.try_into_json_scan_physical_plan(scan, ctx, proto_converter) + } + PhysicalPlanType::ParquetScan(scan) => { + self.try_into_parquet_scan_physical_plan(scan, ctx, proto_converter) } - PhysicalPlanType::ParquetScan(scan) => self - .try_into_parquet_scan_physical_plan(scan, ctx, codec, proto_converter), PhysicalPlanType::AvroScan(scan) => { - self.try_into_avro_scan_physical_plan(scan, ctx, codec, proto_converter) + self.try_into_avro_scan_physical_plan(scan, ctx, proto_converter) } PhysicalPlanType::MemoryScan(scan) => { - self.try_into_memory_scan_physical_plan(scan, 
ctx, codec, proto_converter) + self.try_into_memory_scan_physical_plan(scan, ctx, proto_converter) } PhysicalPlanType::ArrowScan(scan) => { - self.try_into_arrow_scan_physical_plan(scan, ctx, codec, proto_converter) + self.try_into_arrow_scan_physical_plan(scan, ctx, proto_converter) } PhysicalPlanType::CoalesceBatches(coalesce_batches) => self .try_into_coalesce_batches_physical_plan( coalesce_batches, ctx, - codec, proto_converter, ), PhysicalPlanType::Merge(merge) => { - self.try_into_merge_physical_plan(merge, ctx, codec, proto_converter) + self.try_into_merge_physical_plan(merge, ctx, proto_converter) + } + PhysicalPlanType::Repartition(repart) => { + self.try_into_repartition_physical_plan(repart, ctx, proto_converter) + } + PhysicalPlanType::GlobalLimit(limit) => { + self.try_into_global_limit_physical_plan(limit, ctx, proto_converter) + } + PhysicalPlanType::LocalLimit(limit) => { + self.try_into_local_limit_physical_plan(limit, ctx, proto_converter) + } + PhysicalPlanType::Window(window_agg) => { + self.try_into_window_physical_plan(window_agg, ctx, proto_converter) + } + PhysicalPlanType::Aggregate(hash_agg) => { + self.try_into_aggregate_physical_plan(hash_agg, ctx, proto_converter) + } + PhysicalPlanType::HashJoin(hashjoin) => { + self.try_into_hash_join_physical_plan(hashjoin, ctx, proto_converter) } - PhysicalPlanType::Repartition(repart) => self - .try_into_repartition_physical_plan(repart, ctx, codec, proto_converter), - PhysicalPlanType::GlobalLimit(limit) => self - .try_into_global_limit_physical_plan(limit, ctx, codec, proto_converter), - PhysicalPlanType::LocalLimit(limit) => self - .try_into_local_limit_physical_plan(limit, ctx, codec, proto_converter), - PhysicalPlanType::Window(window_agg) => self.try_into_window_physical_plan( - window_agg, - ctx, - codec, - proto_converter, - ), - PhysicalPlanType::Aggregate(hash_agg) => self - .try_into_aggregate_physical_plan(hash_agg, ctx, codec, proto_converter), - PhysicalPlanType::HashJoin(hashjoin) => self - .try_into_hash_join_physical_plan(hashjoin, ctx, codec, proto_converter), PhysicalPlanType::SymmetricHashJoin(sym_join) => self .try_into_symmetric_hash_join_physical_plan( sym_join, ctx, - codec, proto_converter, ), PhysicalPlanType::Union(union) => { - self.try_into_union_physical_plan(union, ctx, codec, proto_converter) + self.try_into_union_physical_plan(union, ctx, proto_converter) + } + PhysicalPlanType::Interleave(interleave) => { + self.try_into_interleave_physical_plan(interleave, ctx, proto_converter) + } + PhysicalPlanType::CrossJoin(crossjoin) => { + self.try_into_cross_join_physical_plan(crossjoin, ctx, proto_converter) } - PhysicalPlanType::Interleave(interleave) => self - .try_into_interleave_physical_plan( - interleave, - ctx, - codec, - proto_converter, - ), - PhysicalPlanType::CrossJoin(crossjoin) => self - .try_into_cross_join_physical_plan( - crossjoin, - ctx, - codec, - proto_converter, - ), PhysicalPlanType::Empty(empty) => { - self.try_into_empty_physical_plan(empty, ctx, codec, proto_converter) + self.try_into_empty_physical_plan(empty, ctx, proto_converter) } PhysicalPlanType::PlaceholderRow(placeholder) => { - self.try_into_placeholder_row_physical_plan(placeholder, ctx, codec) + self.try_into_placeholder_row_physical_plan(placeholder, ctx) } PhysicalPlanType::Sort(sort) => { - self.try_into_sort_physical_plan(sort, ctx, codec, proto_converter) + self.try_into_sort_physical_plan(sort, ctx, proto_converter) } PhysicalPlanType::SortPreservingMerge(sort) => self - 
.try_into_sort_preserving_merge_physical_plan( - sort, - ctx, - codec, - proto_converter, - ), - PhysicalPlanType::Extension(extension) => self - .try_into_extension_physical_plan(extension, ctx, codec, proto_converter), - PhysicalPlanType::NestedLoopJoin(join) => self - .try_into_nested_loop_join_physical_plan( - join, - ctx, - codec, - proto_converter, - ), + .try_into_sort_preserving_merge_physical_plan(sort, ctx, proto_converter), + PhysicalPlanType::Extension(extension) => { + self.try_into_extension_physical_plan(extension, ctx, proto_converter) + } + PhysicalPlanType::NestedLoopJoin(join) => { + self.try_into_nested_loop_join_physical_plan(join, ctx, proto_converter) + } PhysicalPlanType::Analyze(analyze) => { - self.try_into_analyze_physical_plan(analyze, ctx, codec, proto_converter) + self.try_into_analyze_physical_plan(analyze, ctx, proto_converter) } PhysicalPlanType::JsonSink(sink) => { - self.try_into_json_sink_physical_plan(sink, ctx, codec, proto_converter) + self.try_into_json_sink_physical_plan(sink, ctx, proto_converter) } PhysicalPlanType::CsvSink(sink) => { - self.try_into_csv_sink_physical_plan(sink, ctx, codec, proto_converter) + self.try_into_csv_sink_physical_plan(sink, ctx, proto_converter) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] - PhysicalPlanType::ParquetSink(sink) => self - .try_into_parquet_sink_physical_plan(sink, ctx, codec, proto_converter), + PhysicalPlanType::ParquetSink(sink) => { + self.try_into_parquet_sink_physical_plan(sink, ctx, proto_converter) + } PhysicalPlanType::Unnest(unnest) => { - self.try_into_unnest_physical_plan(unnest, ctx, codec, proto_converter) + self.try_into_unnest_physical_plan(unnest, ctx, proto_converter) + } + PhysicalPlanType::Cooperative(cooperative) => { + self.try_into_cooperative_physical_plan(cooperative, ctx, proto_converter) } - PhysicalPlanType::Cooperative(cooperative) => self - .try_into_cooperative_physical_plan( - cooperative, - ctx, - codec, - proto_converter, - ), PhysicalPlanType::GenerateSeries(generate_series) => { self.try_into_generate_series_physical_plan(generate_series) } PhysicalPlanType::SortMergeJoin(sort_join) => { - self.try_into_sort_join(sort_join, ctx, codec, proto_converter) + self.try_into_sort_join(sort_join, ctx, proto_converter) + } + PhysicalPlanType::AsyncFunc(async_func) => { + self.try_into_async_func_physical_plan(async_func, ctx, proto_converter) } - PhysicalPlanType::AsyncFunc(async_func) => self - .try_into_async_func_physical_plan( - async_func, - ctx, - codec, - proto_converter, - ), PhysicalPlanType::Buffer(buffer) => { - self.try_into_buffer_physical_plan(buffer, ctx, codec, proto_converter) + self.try_into_buffer_physical_plan(buffer, ctx, proto_converter) + } + PhysicalPlanType::ScalarSubquery(sq) => { + self.try_into_scalar_subquery_physical_plan(sq, ctx, proto_converter) } } } @@ -569,6 +608,14 @@ impl protobuf::PhysicalPlanNode { ); } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_scalar_subquery_exec( + exec, + codec, + proto_converter, + ); + } + let mut buf: Vec = vec![]; match codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { @@ -602,9 +649,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_explain_physical_plan( &self, explain: &protobuf::ExplainExecNode, - _ctx: &TaskContext, - - _codec: &dyn PhysicalExtensionCodec, + _ctx: &PhysicalPlanDecodeContext<'_>, _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { Ok(Arc::new(ExplainExec::new( @@ -621,13 +666,11 @@ impl 
protobuf::PhysicalPlanNode { fn try_into_projection_physical_plan( &self, projection: &protobuf::ProjectionExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&projection.input, ctx, codec, proto_converter)?; + into_physical_plan(&projection.input, ctx, proto_converter)?; let exprs = projection .expr .iter() @@ -636,9 +679,8 @@ impl protobuf::PhysicalPlanNode { Ok(( proto_converter.proto_to_physical_expr( expr, - ctx, input.schema().as_ref(), - codec, + ctx, )?, name.to_string(), )) @@ -654,24 +696,17 @@ impl protobuf::PhysicalPlanNode { fn try_into_filter_physical_plan( &self, filter: &protobuf::FilterExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&filter.input, ctx, codec, proto_converter)?; + into_physical_plan(&filter.input, ctx, proto_converter)?; let predicate = filter .expr .as_ref() .map(|expr| { - proto_converter.proto_to_physical_expr( - expr, - ctx, - input.schema().as_ref(), - codec, - ) + proto_converter.proto_to_physical_expr(expr, input.schema().as_ref(), ctx) }) .transpose()? .ok_or_else(|| { @@ -711,9 +746,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_csv_scan_physical_plan( &self, scan: &protobuf::CsvScanExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let escape = @@ -755,7 +788,6 @@ impl protobuf::PhysicalPlanNode { let conf = FileScanConfigBuilder::from(parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - codec, proto_converter, source, )?) 
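The same context-threading shape repeats in every scan and operator decoder below, so it is worth seeing the entry point once in isolation. A hedged sketch, assuming the three-argument signature of `try_into_physical_plan_with_converter` shown in mod.rs above (the helper name `decode` is ours, and `node`, `task_ctx`, and `codec` are supplied by the caller):

```rust
use std::sync::Arc;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_proto::physical_plan::{
    DefaultPhysicalProtoConverter, PhysicalExtensionCodec,
};
use datafusion_proto::protobuf::PhysicalPlanNode;

// The public entry point still accepts the TaskContext and codec separately;
// internally they are bundled into a PhysicalPlanDecodeContext that every
// try_into_*_physical_plan helper receives as `ctx`.
fn decode(
    node: &PhysicalPlanNode,
    task_ctx: &TaskContext,
    codec: &dyn PhysicalExtensionCodec,
) -> Result<Arc<dyn ExecutionPlan>> {
    node.try_into_physical_plan_with_converter(
        task_ctx,
        codec,
        &DefaultPhysicalProtoConverter {},
    )
}
```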
@@ -767,9 +799,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_json_scan_physical_plan( &self, scan: &protobuf::JsonScanExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let base_conf = scan.base_conf.as_ref().unwrap(); @@ -777,7 +807,6 @@ impl protobuf::PhysicalPlanNode { let scan_conf = parse_protobuf_file_scan_config( base_conf, ctx, - codec, proto_converter, Arc::new(JsonSource::new(table_schema)), )?; @@ -787,8 +816,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_arrow_scan_physical_plan( &self, scan: &protobuf::ArrowScanExecNode, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let base_conf = scan.base_conf.as_ref().ok_or_else(|| { @@ -798,7 +826,6 @@ impl protobuf::PhysicalPlanNode { let scan_conf = parse_protobuf_file_scan_config( base_conf, ctx, - codec, proto_converter, Arc::new(ArrowSource::new_file_source(table_schema)), )?; @@ -809,8 +836,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_parquet_scan_physical_plan( &self, scan: &protobuf::ParquetScanExecNode, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] @@ -839,9 +865,8 @@ impl protobuf::PhysicalPlanNode { .map(|expr| { proto_converter.proto_to_physical_expr( expr, - ctx, predicate_schema.as_ref(), - codec, + ctx, ) }) .transpose()?; @@ -857,9 +882,15 @@ impl protobuf::PhysicalPlanNode { false => ObjectStoreUrl::parse(&base_conf.object_store_url)?, true => ObjectStoreUrl::local_filesystem(), }; - let store = ctx.runtime_env().object_store(object_store_url)?; - let metadata_cache = - ctx.runtime_env().cache_manager.get_file_metadata_cache(); + let store = ctx + .task_ctx() + .runtime_env() + .object_store(object_store_url)?; + let metadata_cache = ctx + .task_ctx() + .runtime_env() + .cache_manager + .get_file_metadata_cache(); let reader_factory = Arc::new(CachedParquetFileReaderFactory::new(store, metadata_cache)); @@ -873,7 +904,6 @@ impl protobuf::PhysicalPlanNode { let base_config = parse_protobuf_file_scan_config( base_conf, ctx, - codec, proto_converter, Arc::new(source), )?; @@ -889,8 +919,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_avro_scan_physical_plan( &self, scan: &protobuf::AvroScanExecNode, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "avro")] @@ -900,7 +929,6 @@ impl protobuf::PhysicalPlanNode { let conf = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - codec, proto_converter, Arc::new(AvroSource::new(table_schema)), )?; @@ -914,9 +942,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_memory_scan_physical_plan( &self, scan: &protobuf::MemoryScanExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let partitions = scan @@ -947,7 +973,6 @@ impl protobuf::PhysicalPlanNode { &ordering.physical_sort_expr_nodes, ctx, &schema, - codec, proto_converter, )?; sort_information.extend(LexOrdering::new(sort_exprs)); @@ -965,13 +990,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_coalesce_batches_physical_plan( &self, 
coalesce_batches: &protobuf::CoalesceBatchesExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&coalesce_batches.input, ctx, codec, proto_converter)?; + into_physical_plan(&coalesce_batches.input, ctx, proto_converter)?; Ok(Arc::new( #[expect(deprecated)] CoalesceBatchesExec::new(input, coalesce_batches.target_batch_size as usize) @@ -982,13 +1005,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_merge_physical_plan( &self, merge: &protobuf::CoalescePartitionsExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&merge.input, ctx, codec, proto_converter)?; + into_physical_plan(&merge.input, ctx, proto_converter)?; Ok(Arc::new( CoalescePartitionsExec::new(input) .with_fetch(merge.fetch.map(|f| f as usize)), @@ -998,18 +1019,15 @@ impl protobuf::PhysicalPlanNode { fn try_into_repartition_physical_plan( &self, repart: &protobuf::RepartitionExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&repart.input, ctx, codec, proto_converter)?; + into_physical_plan(&repart.input, ctx, proto_converter)?; let partitioning = parse_protobuf_partitioning( repart.partitioning.as_ref(), ctx, input.schema().as_ref(), - codec, proto_converter, )?; let mut repart_exec = RepartitionExec::try_new(input, partitioning.unwrap())?; @@ -1022,13 +1040,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_global_limit_physical_plan( &self, limit: &protobuf::GlobalLimitExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, codec, proto_converter)?; + into_physical_plan(&limit.input, ctx, proto_converter)?; let fetch = if limit.fetch >= 0 { Some(limit.fetch as usize) } else { @@ -1044,26 +1060,22 @@ impl protobuf::PhysicalPlanNode { fn try_into_local_limit_physical_plan( &self, limit: &protobuf::LocalLimitExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, codec, proto_converter)?; + into_physical_plan(&limit.input, ctx, proto_converter)?; Ok(Arc::new(LocalLimitExec::new(input, limit.fetch as usize))) } fn try_into_window_physical_plan( &self, window_agg: &protobuf::WindowAggExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&window_agg.input, ctx, codec, proto_converter)?; + into_physical_plan(&window_agg.input, ctx, proto_converter)?; let input_schema = input.schema(); let physical_window_expr: Vec> = window_agg @@ -1074,7 +1086,6 @@ impl protobuf::PhysicalPlanNode { window_expr, ctx, input_schema.as_ref(), - codec, proto_converter, ) }) @@ -1084,12 +1095,7 @@ impl protobuf::PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - proto_converter.proto_to_physical_expr( - expr, - ctx, - 
input.schema().as_ref(), - codec, - ) + proto_converter.proto_to_physical_expr(expr, input.schema().as_ref(), ctx) }) .collect::>>>()?; @@ -1122,13 +1128,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_aggregate_physical_plan( &self, hash_agg: &protobuf::AggregateExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&hash_agg.input, ctx, codec, proto_converter)?; + into_physical_plan(&hash_agg.input, ctx, proto_converter)?; let mode = protobuf::AggregateMode::try_from(hash_agg.mode).map_err(|_| { proto_error(format!( "Received a AggregateNode message with unknown AggregateMode {}", @@ -1154,7 +1158,7 @@ impl protobuf::PhysicalPlanNode { .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { proto_converter - .proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec) + .proto_to_physical_expr(expr, input.schema().as_ref(), ctx) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -1165,7 +1169,7 @@ impl protobuf::PhysicalPlanNode { .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { proto_converter - .proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec) + .proto_to_physical_expr(expr, input.schema().as_ref(), ctx) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -1194,12 +1198,7 @@ impl protobuf::PhysicalPlanNode { expr.expr .as_ref() .map(|e| { - proto_converter.proto_to_physical_expr( - e, - ctx, - &physical_schema, - codec, - ) + proto_converter.proto_to_physical_expr(e, &physical_schema, ctx) }) .transpose() }) @@ -1222,9 +1221,8 @@ impl protobuf::PhysicalPlanNode { .map(|e| { proto_converter.proto_to_physical_expr( e, - ctx, &physical_schema, - codec, + ctx, ) }) .collect::>>()?; @@ -1236,7 +1234,6 @@ impl protobuf::PhysicalPlanNode { e, ctx, &physical_schema, - codec, proto_converter, ) }) @@ -1248,11 +1245,14 @@ impl protobuf::PhysicalPlanNode { AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = match &agg_node.fun_definition { Some(buf) => { - codec.try_decode_udaf(udaf_name, buf)? + ctx.codec().try_decode_udaf(udaf_name, buf)? 
} - None => ctx.udaf(udaf_name).or_else(|_| { - codec.try_decode_udaf(udaf_name, &[]) - })?, + None => ctx.task_ctx().udaf(udaf_name).or_else( + |_| { + ctx.codec() + .try_decode_udaf(udaf_name, &[]) + }, + )?, }; AggregateExprBuilder::new(agg_udf, input_phy_expr) @@ -1304,15 +1304,13 @@ impl protobuf::PhysicalPlanNode { fn try_into_hash_join_physical_plan( &self, hashjoin: &protobuf::HashJoinExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&hashjoin.left, ctx, codec, proto_converter)?; + into_physical_plan(&hashjoin.left, ctx, proto_converter)?; let right: Arc = - into_physical_plan(&hashjoin.right, ctx, codec, proto_converter)?; + into_physical_plan(&hashjoin.right, ctx, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin @@ -1321,15 +1319,13 @@ impl protobuf::PhysicalPlanNode { .map(|col| { let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), - ctx, left_schema.as_ref(), - codec, + ctx, )?; let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), - ctx, right_schema.as_ref(), - codec, + ctx, )?; Ok((left, right)) }) @@ -1362,8 +1358,8 @@ impl protobuf::PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, - codec, + &schema, + ctx, )?; let column_indices = f.column_indices .iter() @@ -1424,13 +1420,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_symmetric_hash_join_physical_plan( &self, sym_join: &protobuf::SymmetricHashJoinExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let left = into_physical_plan(&sym_join.left, ctx, codec, proto_converter)?; - let right = into_physical_plan(&sym_join.right, ctx, codec, proto_converter)?; + let left = into_physical_plan(&sym_join.left, ctx, proto_converter)?; + let right = into_physical_plan(&sym_join.right, ctx, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on = sym_join @@ -1439,15 +1433,13 @@ impl protobuf::PhysicalPlanNode { .map(|col| { let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), - ctx, left_schema.as_ref(), - codec, + ctx, )?; let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), - ctx, right_schema.as_ref(), - codec, + ctx, )?; Ok((left, right)) }) @@ -1480,8 +1472,8 @@ impl protobuf::PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, - codec, + &schema, + ctx, )?; let column_indices = f.column_indices .iter() @@ -1507,7 +1499,6 @@ impl protobuf::PhysicalPlanNode { &sym_join.left_sort_exprs, ctx, &left_schema, - codec, proto_converter, )?; let left_sort_exprs = LexOrdering::new(left_sort_exprs); @@ -1516,7 +1507,6 @@ impl protobuf::PhysicalPlanNode { &sym_join.right_sort_exprs, ctx, &right_schema, - codec, proto_converter, )?; let right_sort_exprs = LexOrdering::new(right_sort_exprs); @@ -1555,14 +1545,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_union_physical_plan( &self, union: &protobuf::UnionExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn 
PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &union.inputs { - inputs.push(proto_converter.proto_to_execution_plan(ctx, codec, input)?); + inputs.push(proto_converter.proto_to_execution_plan(input, ctx)?); } UnionExec::try_new(inputs) } @@ -1570,14 +1558,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_interleave_physical_plan( &self, interleave: &protobuf::InterleaveExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &interleave.inputs { - inputs.push(proto_converter.proto_to_execution_plan(ctx, codec, input)?); + inputs.push(proto_converter.proto_to_execution_plan(input, ctx)?); } Ok(Arc::new(InterleaveExec::try_new(inputs)?)) } @@ -1585,24 +1571,20 @@ impl protobuf::PhysicalPlanNode { fn try_into_cross_join_physical_plan( &self, crossjoin: &protobuf::CrossJoinExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&crossjoin.left, ctx, codec, proto_converter)?; + into_physical_plan(&crossjoin.left, ctx, proto_converter)?; let right: Arc = - into_physical_plan(&crossjoin.right, ctx, codec, proto_converter)?; + into_physical_plan(&crossjoin.right, ctx, proto_converter)?; Ok(Arc::new(CrossJoinExec::new(left, right))) } fn try_into_empty_physical_plan( &self, empty: &protobuf::EmptyExecNode, - _ctx: &TaskContext, - - _codec: &dyn PhysicalExtensionCodec, + _ctx: &PhysicalPlanDecodeContext<'_>, _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let schema = Arc::new(convert_required!(empty.schema)?); @@ -1612,9 +1594,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_placeholder_row_physical_plan( &self, placeholder: &protobuf::PlaceholderRowExecNode, - _ctx: &TaskContext, - - _codec: &dyn PhysicalExtensionCodec, + _ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result> { let schema = Arc::new(convert_required!(placeholder.schema)?); Ok(Arc::new(PlaceholderRowExec::new(schema))) @@ -1623,12 +1603,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_sort_physical_plan( &self, sort: &protobuf::SortExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, codec, proto_converter)?; + let input = into_physical_plan(&sort.input, ctx, proto_converter)?; let exprs = sort .expr .iter() @@ -1649,7 +1627,11 @@ impl protobuf::PhysicalPlanNode { })? 
.as_ref(); Ok(PhysicalSortExpr { - expr: proto_converter.proto_to_physical_expr(expr, ctx, input.schema().as_ref(), codec)?, + expr: proto_converter.proto_to_physical_expr( + expr, + input.schema().as_ref(), + ctx, + )?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -1676,12 +1658,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_sort_preserving_merge_physical_plan( &self, sort: &protobuf::SortPreservingMergeExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, codec, proto_converter)?; + let input = into_physical_plan(&sort.input, ctx, proto_converter)?; let exprs = sort .expr .iter() @@ -1704,9 +1684,8 @@ impl protobuf::PhysicalPlanNode { Ok(PhysicalSortExpr { expr: proto_converter.proto_to_physical_expr( expr, - ctx, input.schema().as_ref(), - codec, + ctx, )?, options: SortOptions { descending: !sort_expr.asc, @@ -1730,18 +1709,18 @@ impl protobuf::PhysicalPlanNode { fn try_into_extension_physical_plan( &self, extension: &protobuf::PhysicalExtensionNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let inputs: Vec> = extension .inputs .iter() - .map(|i| proto_converter.proto_to_execution_plan(ctx, codec, i)) + .map(|i| proto_converter.proto_to_execution_plan(i, ctx)) .collect::>()?; - let extension_node = codec.try_decode(extension.node.as_slice(), &inputs, ctx)?; + let extension_node = + ctx.codec() + .try_decode(extension.node.as_slice(), &inputs, ctx.task_ctx())?; Ok(extension_node) } @@ -1749,15 +1728,13 @@ impl protobuf::PhysicalPlanNode { fn try_into_nested_loop_join_physical_plan( &self, join: &protobuf::NestedLoopJoinExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&join.left, ctx, codec, proto_converter)?; + into_physical_plan(&join.left, ctx, proto_converter)?; let right: Arc = - into_physical_plan(&join.right, ctx, codec, proto_converter)?; + into_physical_plan(&join.right, ctx, proto_converter)?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { proto_error(format!( "Received a NestedLoopJoinExecNode message with unknown JoinType {}", @@ -1774,12 +1751,13 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? 
.try_into()?; - let expression = proto_converter.proto_to_physical_expr( + let expression = proto_converter + .proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, - codec, + &schema, + ctx, )?; let column_indices = f.column_indices .iter() @@ -1824,13 +1802,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_analyze_physical_plan( &self, analyze: &protobuf::AnalyzeExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&analyze.input, ctx, codec, proto_converter)?; + into_physical_plan(&analyze.input, ctx, proto_converter)?; let metric_categories = if analyze.has_metric_categories { let cats: Result> = analyze .metric_categories @@ -1854,12 +1830,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_json_sink_physical_plan( &self, sink: &protobuf::JsonSinkExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; + let input = into_physical_plan(&sink.input, ctx, proto_converter)?; let data_sink: JsonSink = sink .sink @@ -1875,7 +1849,6 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - codec, proto_converter, ) .map(|sort_exprs| { @@ -1894,12 +1867,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_csv_sink_physical_plan( &self, sink: &protobuf::CsvSinkExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; + let input = into_physical_plan(&sink.input, ctx, proto_converter)?; let data_sink: CsvSink = sink .sink @@ -1915,7 +1886,6 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - codec, proto_converter, ) .map(|sort_exprs| { @@ -1935,14 +1905,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_parquet_sink_physical_plan( &self, sink: &protobuf::ParquetSinkExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] { - let input = into_physical_plan(&sink.input, ctx, codec, proto_converter)?; + let input = into_physical_plan(&sink.input, ctx, proto_converter)?; let data_sink: ParquetSink = sink .sink @@ -1958,7 +1926,6 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - codec, proto_converter, ) .map(|sort_exprs| { @@ -1980,12 +1947,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_unnest_physical_plan( &self, unnest: &protobuf::UnnestExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&unnest.input, ctx, codec, proto_converter)?; + let input = into_physical_plan(&unnest.input, ctx, proto_converter)?; Ok(Arc::new(UnnestExec::new( input, @@ -2012,14 +1977,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_sort_join( &self, sort_join: &SortMergeJoinExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: 
&PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<Arc<dyn ExecutionPlan>> { - let left = into_physical_plan(&sort_join.left, ctx, codec, proto_converter)?; + let left = into_physical_plan(&sort_join.left, ctx, proto_converter)?; let left_schema = left.schema(); - let right = into_physical_plan(&sort_join.right, ctx, codec, proto_converter)?; + let right = into_physical_plan(&sort_join.right, ctx, proto_converter)?; let right_schema = right.schema(); let filter = sort_join @@ -2036,9 +1999,8 @@ impl protobuf::PhysicalPlanNode { f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, - codec, + ctx, )?; let column_indices = f .column_indices @@ -2097,15 +2059,13 @@ impl protobuf::PhysicalPlanNode { .map(|col| { let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), - ctx, left_schema.as_ref(), - codec, + ctx, )?; let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), - ctx, right_schema.as_ref(), - codec, + ctx, )?; Ok((left, right)) }) @@ -2190,24 +2150,21 @@ impl protobuf::PhysicalPlanNode { fn try_into_cooperative_physical_plan( &self, field_stream: &protobuf::CooperativeExecNode, - ctx: &TaskContext, - - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<Arc<dyn ExecutionPlan>> { - let input = into_physical_plan(&field_stream.input, ctx, codec, proto_converter)?; + let input = into_physical_plan(&field_stream.input, ctx, proto_converter)?; Ok(Arc::new(CooperativeExec::new(input))) } fn try_into_async_func_physical_plan( &self, async_func: &protobuf::AsyncFuncExecNode, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<Arc<dyn ExecutionPlan>> { let input: Arc<dyn ExecutionPlan> = - into_physical_plan(&async_func.input, ctx, codec, proto_converter)?; + into_physical_plan(&async_func.input, ctx, proto_converter)?; if async_func.async_exprs.len() != async_func.async_expr_names.len() { return internal_err!( @@ -2222,9 +2179,8 @@ impl protobuf::PhysicalPlanNode { .map(|(expr, name)| { let physical_expr = proto_converter.proto_to_physical_expr( expr, - ctx, input.schema().as_ref(), - codec, + ctx, )?; Ok(Arc::new(AsyncFuncExpr::try_new( @@ -2241,16 +2197,49 @@ impl protobuf::PhysicalPlanNode { fn try_into_buffer_physical_plan( &self, buffer: &protobuf::BufferExecNode, - ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<Arc<dyn ExecutionPlan>> { let input: Arc<dyn ExecutionPlan> = - into_physical_plan(&buffer.input, ctx, extension_codec, proto_converter)?; + into_physical_plan(&buffer.input, ctx, proto_converter)?; Ok(Arc::new(BufferExec::new(input, buffer.capacity as usize))) } + fn try_into_scalar_subquery_physical_plan( + &self, + sq: &protobuf::ScalarSubqueryExecNode, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<Arc<dyn ExecutionPlan>> { + // Set up the shared subquery results container before deserializing the + // main input plan, so that ScalarSubqueryExpr nodes can reference it. + let subquery_results = ScalarSubqueryResults::new(sq.subqueries.len()); + let input_ctx = ctx.with_scalar_subquery_results(subquery_results.clone()); + let input = into_physical_plan(&sq.input, &input_ctx, proto_converter)?; + + // Now deserialize the subquery children.
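+ // They are decoded with the plain `ctx` rather than `input_ctx`, so a + // scalar subquery expression inside a child plan binds to the container + // of its own enclosing ScalarSubqueryExec, never to this node's.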
+ let subqueries: Vec<ScalarSubqueryLink> = sq + .subqueries + .iter() + .enumerate() + .map(|(index, sq_plan)| { + let plan = + sq_plan.try_into_physical_plan_with_context(ctx, proto_converter)?; + Ok(ScalarSubqueryLink { + plan, + index: SubqueryIndex::new(index), + }) + }) + .collect::<Result<Vec<_>>>()?; + + Ok(Arc::new(ScalarSubqueryExec::new( + input, + subqueries, + subquery_results, + ))) + } + fn try_from_explain_exec( exec: &ExplainExec, _codec: &dyn PhysicalExtensionCodec, @@ -3643,6 +3632,38 @@ impl protobuf::PhysicalPlanNode { ))), }) } + + fn try_from_scalar_subquery_exec( + exec: &ScalarSubqueryExec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<Self> { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(exec.input()), + codec, + proto_converter, + )?; + let subqueries = exec + .subqueries() + .iter() + .map(|sq| { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(&sq.plan), + codec, + proto_converter, + ) + }) + .collect::<Result<Vec<_>>>()?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ScalarSubquery(Box::new( + protobuf::ScalarSubqueryExecNode { + input: Some(Box::new(input)), + subqueries, + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { @@ -3751,11 +3772,21 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { pub trait PhysicalProtoConverterExtension { fn proto_to_execution_plan( &self, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn ExecutionPlan>>; + fn default_proto_to_execution_plan( + &self, + proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result<Arc<dyn ExecutionPlan>> + where + Self: Sized, + { + proto.try_into_physical_plan_with_context(ctx, self) + } + fn execution_plan_to_proto( &self, plan: &Arc<dyn ExecutionPlan>, @@ -3765,11 +3796,22 @@ pub trait PhysicalProtoConverterExtension { fn proto_to_physical_expr( &self, proto: &protobuf::PhysicalExprNode, - ctx: &TaskContext, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn PhysicalExpr>>; + fn default_proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + input_schema: &Schema, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result<Arc<dyn PhysicalExpr>> + where + Self: Sized, + { + parse_physical_expr_with_converter(proto, input_schema, ctx, self) + } + fn physical_expr_to_proto( &self, expr: &Arc<dyn PhysicalExpr>, @@ -3790,15 +3832,16 @@ struct DataEncoderTuple { pub blob: Vec<u8>, } -pub struct DefaultPhysicalProtoConverter; +#[derive(Default)] +pub struct DefaultPhysicalProtoConverter {} + impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { fn proto_to_execution_plan( &self, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn ExecutionPlan>> { - proto.try_into_physical_plan_with_converter(ctx, codec, self) + proto.try_into_physical_plan_with_context(ctx, self) } fn execution_plan_to_proto( @@ -3819,15 +3862,14 @@ impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { fn proto_to_physical_expr( &self, proto: &protobuf::PhysicalExprNode, - ctx: &TaskContext, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn PhysicalExpr>> where Self: Sized, { // Default implementation calls the free function - parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + 
parse_physical_expr_with_converter(proto, input_schema, ctx, self) } fn physical_expr_to_proto( @@ -3857,9 +3899,8 @@ impl DeduplicatingSerializer { impl PhysicalProtoConverterExtension for DeduplicatingSerializer { fn proto_to_execution_plan( &self, - _ctx: &TaskContext, - _codec: &dyn PhysicalExtensionCodec, _proto: &protobuf::PhysicalPlanNode, + _ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn ExecutionPlan>> { internal_err!("DeduplicatingSerializer cannot deserialize execution plans") } @@ -3882,9 +3923,8 @@ impl PhysicalProtoConverterExtension for DeduplicatingSerializer { fn proto_to_physical_expr( &self, _proto: &protobuf::PhysicalExprNode, - _ctx: &TaskContext, _input_schema: &Schema, - _codec: &dyn PhysicalExtensionCodec, + _ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn PhysicalExpr>> where Self: Sized, @@ -3924,11 +3964,10 @@ struct DeduplicatingDeserializer { impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { fn proto_to_execution_plan( &self, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn ExecutionPlan>> { - proto.try_into_physical_plan_with_converter(ctx, codec, self) + proto.try_into_physical_plan_with_context(ctx, self) } fn execution_plan_to_proto( @@ -3945,9 +3984,8 @@ impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { fn proto_to_physical_expr( &self, proto: &protobuf::PhysicalExprNode, - ctx: &TaskContext, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn PhysicalExpr>> where Self: Sized, @@ -3958,17 +3996,12 @@ impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { return Ok(Arc::clone(cached)); } // Deserialize and cache - let expr = parse_physical_expr_with_converter( - proto, - ctx, - input_schema, - codec, - self, - )?; + let expr = + parse_physical_expr_with_converter(proto, input_schema, ctx, self)?; self.cache.borrow_mut().insert(expr_id, Arc::clone(&expr)); Ok(expr) } else { - parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + parse_physical_expr_with_converter(proto, input_schema, ctx, self) } } @@ -4002,12 +4035,11 @@ pub struct DeduplicatingProtoConverter {} impl PhysicalProtoConverterExtension for DeduplicatingProtoConverter { fn proto_to_execution_plan( &self, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn ExecutionPlan>> { let deserializer = DeduplicatingDeserializer::default(); - proto.try_into_physical_plan_with_converter(ctx, codec, &deserializer) + proto.try_into_physical_plan_with_context(ctx, &deserializer) } fn execution_plan_to_proto( @@ -4029,15 +4061,14 @@ impl PhysicalProtoConverterExtension for DeduplicatingProtoConverter { fn proto_to_physical_expr( &self, proto: &protobuf::PhysicalExprNode, - ctx: &TaskContext, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result<Arc<dyn PhysicalExpr>> where Self: Sized, { let deserializer = DeduplicatingDeserializer::default(); - deserializer.proto_to_physical_expr(proto, ctx, input_schema, codec) + deserializer.proto_to_physical_expr(proto, input_schema, ctx) } fn physical_expr_to_proto( @@ -4151,12 +4182,11 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { fn into_physical_plan( node: &Option<Box<protobuf::PhysicalPlanNode>>, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<Arc<dyn ExecutionPlan>> { if let Some(field) = node { -
proto_converter.proto_to_execution_plan(ctx, codec, field) + proto_converter.proto_to_execution_plan(field, ctx) } else { Err(proto_error("Missing required field in protobuf")) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index f0eb6d27aac30..cb69b913b5e8b 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -32,6 +32,7 @@ use datafusion_datasource_json::file_format::JsonSink; use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_expr::WindowFrame; use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -526,6 +527,17 @@ pub fn serialize_physical_expr_with_converter( }, )), }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarSubquery( + protobuf::PhysicalScalarSubqueryExprNode { + data_type: Some(expr.data_type().try_into()?), + nullable: expr.nullable(), + index: expr.index().as_usize() as u32, + }, + )), + }) } else { let mut buf: Vec = vec![]; match codec.try_encode_expr(value, &mut buf) { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6de9dd4caa9b4..c0570881e37a3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -134,7 +134,8 @@ fn roundtrip_expr_test_with_codec( ) { let proto: protobuf::LogicalExprNode = serialize_expr(&initial_struct, codec) .unwrap_or_else(|e| panic!("Error serializing expression: {e:?}")); - let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, codec).unwrap(); + let round_trip: Expr = + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -2571,7 +2572,8 @@ fn roundtrip_scalar_udf_extension_codec() { let ctx = SessionContext::new(); let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); @@ -2584,7 +2586,8 @@ fn roundtrip_aggregate_udf_extension_codec() { let ctx = SessionContext::new(); let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 3f1c306603bc1..e7d38b57a1522 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -74,6 +74,9 @@ use datafusion::physical_plan::metrics::MetricType; use 
datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; +use datafusion::physical_plan::scalar_subquery::{ + ScalarSubqueryExec, ScalarSubqueryLink, +}; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; @@ -102,6 +105,7 @@ use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, WindowUDF, + execution_props::{ScalarSubqueryResults, SubqueryIndex}, }; use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; use datafusion_functions_aggregate::array_agg::array_agg_udaf; @@ -109,15 +113,15 @@ use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; use datafusion_functions_aggregate::string_agg::string_agg_udaf; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_proto::bytes::{ physical_plan_from_bytes_with_proto_converter, physical_plan_to_bytes_with_proto_converter, }; -use datafusion_proto::physical_plan::from_proto::parse_physical_expr_with_converter; use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; use datafusion_proto::physical_plan::{ AsExecutionPlan, DeduplicatingProtoConverter, DefaultPhysicalExtensionCodec, - DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, PhysicalPlanDecodeContext, PhysicalProtoConverterExtension, }; use datafusion_proto::protobuf; @@ -2556,9 +2560,8 @@ fn custom_proto_converter_intercepts() -> Result<()> { impl PhysicalProtoConverterExtension for CustomConverterInterceptor { fn proto_to_execution_plan( &self, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result> { { let mut counter = self @@ -2567,7 +2570,7 @@ fn custom_proto_converter_intercepts() -> Result<()> { .map_err(|err| exec_datafusion_err!("{err}"))?; *counter += 1; } - proto.try_into_physical_plan_with_converter(ctx, codec, self) + self.default_proto_to_execution_plan(proto, ctx) } fn execution_plan_to_proto( @@ -2595,9 +2598,8 @@ fn custom_proto_converter_intercepts() -> Result<()> { fn proto_to_physical_expr( &self, proto: &PhysicalExprNode, - ctx: &TaskContext, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result> where Self: Sized, @@ -2609,7 +2611,7 @@ fn custom_proto_converter_intercepts() -> Result<()> { .map_err(|err| exec_datafusion_err!("{err}"))?; *counter += 1; } - parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self) + self.default_proto_to_physical_expr(proto, input_schema, ctx) } fn physical_expr_to_proto( @@ -2837,14 +2839,11 @@ fn test_backward_compatibility_no_expr_id() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; let proto_converter = DefaultPhysicalProtoConverter {}; + let task_ctx = ctx.task_ctx(); + let decode_ctx = PhysicalPlanDecodeContext::new(task_ctx.as_ref(), &codec); // Should deserialize without error - let result = proto_converter.proto_to_physical_expr( - 
&proto, - ctx.task_ctx().as_ref(), - &schema, - &codec, - )?; + let result = proto_converter.proto_to_physical_expr(&proto, &schema, &decode_ctx)?; // Verify the result is correct let col = result.downcast_ref::().expect("Expected Column"); @@ -2964,17 +2963,14 @@ fn test_deduplication_within_expr_deserialization() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; let proto_converter = DeduplicatingProtoConverter {}; + let task_ctx = ctx.task_ctx(); + let decode_ctx = PhysicalPlanDecodeContext::new(task_ctx.as_ref(), &codec); // Serialize the expression let proto = proto_converter.physical_expr_to_proto(&binary_expr, &codec)?; // First expression deserialization - let expr1 = proto_converter.proto_to_physical_expr( - &proto, - ctx.task_ctx().as_ref(), - &schema, - &codec, - )?; + let expr1 = proto_converter.proto_to_physical_expr(&proto, &schema, &decode_ctx)?; // Check that deduplication worked within the deserialization let binary1 = expr1 @@ -2986,12 +2982,7 @@ fn test_deduplication_within_expr_deserialization() -> Result<()> { ); // Second expression deserialization - let expr2 = proto_converter.proto_to_physical_expr( - &proto, - ctx.task_ctx().as_ref(), - &schema, - &codec, - )?; + let expr2 = proto_converter.proto_to_physical_expr(&proto, &schema, &decode_ctx)?; // Check that the second expression was also deserialized correctly let binary2 = expr2 @@ -3167,6 +3158,243 @@ fn roundtrip_lead_with_default_value() -> Result<()> { )?)) } +/// Verify that ScalarSubqueryExpr nodes in the input plan are connected to the +/// same shared results container as ScalarSubqueryExec after a proto round-trip. +#[test] +fn roundtrip_scalar_subquery_exec() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let results = ScalarSubqueryResults::new(1); + + // Build the input plan: a filter whose predicate references the + // scalar subquery result via ScalarSubqueryExpr. + let sq_expr = Arc::new(ScalarSubqueryExpr::new( + DataType::Int64, + true, + SubqueryIndex::new(0), + results.clone(), + )); + let predicate = binary(col("a", &schema)?, Operator::Eq, sq_expr, &schema)?; + let filter = + FilterExec::try_new(predicate, Arc::new(EmptyExec::new(schema.clone())))?; + + // Build a trivial subquery plan. + let subquery_plan = + Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Int64, + true, + )])))); + + let exec: Arc = Arc::new(ScalarSubqueryExec::new( + Arc::new(filter), + vec![ScalarSubqueryLink { + plan: subquery_plan, + index: SubqueryIndex::new(0), + }], + results, + )); + + // Perform the round-trip using DeduplicatingProtoConverter, which + // creates a DeduplicatingDeserializer that threads scalar subquery + // results through expression deserialization. + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec), + &codec, + &converter, + )?; + let ctx = SessionContext::new(); + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &converter, + )?; + + // Verify the deserialized ScalarSubqueryExec's results container is + // shared with the ScalarSubqueryExpr in the input plan. 
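+ // ScalarSubqueryResults::ptr_eq compares Arc identity, so the assertion + // below only passes if deserialization rewired the expression to the + // exec's own container; an equal-but-separate container would still fail.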
+ let sq_exec = deserialized + .downcast_ref::() + .expect("expected ScalarSubqueryExec"); + let exec_results = sq_exec.results(); + + // Walk the input plan to find the ScalarSubqueryExpr and verify it + // points to the same results container. + let filter_exec = sq_exec + .input() + .downcast_ref::() + .expect("expected FilterExec"); + let binary_expr = filter_exec + .predicate() + .downcast_ref::() + .expect("expected BinaryExpr"); + let deserialized_sq_expr = binary_expr + .right() + .downcast_ref::() + .expect("expected ScalarSubqueryExpr"); + + assert!( + ScalarSubqueryResults::ptr_eq(exec_results, deserialized_sq_expr.results()), + "ScalarSubqueryExpr should share the same results container as ScalarSubqueryExec" + ); + Ok(()) +} + +/// Verify that nested ScalarSubqueryExec nodes deserialize with distinct +/// scoped results containers, and that each ScalarSubqueryExpr is wired to the +/// container for its own surrounding ScalarSubqueryExec. +#[test] +fn roundtrip_nested_scalar_subquery_exec_scopes_results() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let subquery_schema = + Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)])); + + let inner_results = ScalarSubqueryResults::new(1); + let inner_sq_expr = Arc::new(ScalarSubqueryExpr::new( + DataType::Int64, + true, + SubqueryIndex::new(0), + inner_results.clone(), + )); + let inner_predicate = + binary(col("a", &schema)?, Operator::Eq, inner_sq_expr, &schema)?; + let inner_filter = Arc::new(FilterExec::try_new( + inner_predicate, + Arc::new(EmptyExec::new(schema.clone())), + )?); + let inner_exec: Arc = Arc::new(ScalarSubqueryExec::new( + inner_filter, + vec![ScalarSubqueryLink { + plan: Arc::new(EmptyExec::new(subquery_schema.clone())), + index: SubqueryIndex::new(0), + }], + inner_results, + )); + + let outer_results = ScalarSubqueryResults::new(1); + let outer_sq_expr = Arc::new(ScalarSubqueryExpr::new( + DataType::Int64, + true, + SubqueryIndex::new(0), + outer_results.clone(), + )); + let outer_predicate = + binary(col("a", &schema)?, Operator::Eq, outer_sq_expr, &schema)?; + let outer_filter = Arc::new(FilterExec::try_new(outer_predicate, inner_exec)?); + let outer_exec: Arc = Arc::new(ScalarSubqueryExec::new( + outer_filter, + vec![ScalarSubqueryLink { + plan: Arc::new(EmptyExec::new(subquery_schema)), + index: SubqueryIndex::new(0), + }], + outer_results, + )); + + let bytes = datafusion_proto::bytes::physical_plan_to_bytes(Arc::clone(&outer_exec))?; + let ctx = SessionContext::new(); + let deserialized = datafusion_proto::bytes::physical_plan_from_bytes( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + )?; + + let outer_exec = deserialized + .downcast_ref::() + .expect("expected outer ScalarSubqueryExec"); + let outer_results = outer_exec.results(); + let outer_filter = outer_exec + .input() + .downcast_ref::() + .expect("expected outer FilterExec"); + let outer_binary = outer_filter + .predicate() + .downcast_ref::() + .expect("expected outer BinaryExpr"); + let outer_sq_expr = outer_binary + .right() + .downcast_ref::() + .expect("expected outer ScalarSubqueryExpr"); + + let inner_exec = outer_filter + .input() + .downcast_ref::() + .expect("expected inner ScalarSubqueryExec"); + let inner_results = inner_exec.results(); + let inner_filter = inner_exec + .input() + .downcast_ref::() + .expect("expected inner FilterExec"); + let inner_binary = inner_filter + .predicate() + .downcast_ref::() + .expect("expected inner BinaryExpr"); + let inner_sq_expr = 
inner_binary + .right() + .downcast_ref::() + .expect("expected inner ScalarSubqueryExpr"); + + assert!( + ScalarSubqueryResults::ptr_eq(outer_results, outer_sq_expr.results()), + "outer ScalarSubqueryExpr should use outer ScalarSubqueryExec results" + ); + assert!( + ScalarSubqueryResults::ptr_eq(inner_results, inner_sq_expr.results()), + "inner ScalarSubqueryExpr should use inner ScalarSubqueryExec results" + ); + assert!( + !ScalarSubqueryResults::ptr_eq(outer_results, inner_results), + "nested ScalarSubqueryExec nodes should not share results containers" + ); + assert!( + !ScalarSubqueryResults::ptr_eq(outer_results, inner_sq_expr.results()), + "inner ScalarSubqueryExpr must not read from outer results" + ); + assert!( + !ScalarSubqueryResults::ptr_eq(inner_results, outer_sq_expr.results()), + "outer ScalarSubqueryExpr must not read from inner results" + ); + + Ok(()) +} + +/// Verify that the default physical plan bytes round-trip preserves executable +/// scalar subquery plans. +#[tokio::test] +async fn roundtrip_scalar_subquery_exec_with_default_converter_executes() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "SELECT x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) AS s \ + FROM (VALUES (2), (1)) AS t(x) \ + ORDER BY s"; + + let initial_plan = ctx.sql(sql).await?.create_physical_plan().await?; + assert!( + format!("{initial_plan:?}").contains("ScalarSubqueryExec"), + "expected ScalarSubqueryExec in plan:\n{initial_plan:?}" + ); + + let bytes = + datafusion_proto::bytes::physical_plan_to_bytes(Arc::clone(&initial_plan))?; + let roundtripped = datafusion_proto::bytes::physical_plan_from_bytes( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + )?; + assert!( + format!("{roundtripped:?}").contains("ScalarSubqueryExec"), + "expected ScalarSubqueryExec after roundtrip:\n{roundtripped:?}" + ); + + let batches = datafusion::physical_plan::common::collect( + roundtripped.execute(0, ctx.task_ctx())?, + ) + .await?; + datafusion::assert_batches_eq!( + &["+----+", "| s |", "+----+", "| 21 |", "| 22 |", "+----+",], + &batches + ); + + Ok(()) +} + /// Test that a chain of the same operator (a AND b AND c) is linearized /// and roundtrips correctly. 
#[test] diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index bb955a426ca78..850fd42ce131b 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -77,7 +77,8 @@ fn udf_roundtrip_with_registry() { .call(vec![lit("")]); let bytes = expr.to_bytes().unwrap(); - let deserialized_expr = Expr::from_bytes_with_registry(&bytes, &ctx).unwrap(); + let deserialized_expr = + Expr::from_bytes_with_ctx(&bytes, ctx.task_ctx().as_ref()).unwrap(); assert_eq!(expr, deserialized_expr); } @@ -281,7 +282,8 @@ fn test_expression_serialization_roundtrip() { let extension_codec = DefaultLogicalExtensionCodec {}; let proto = serialize_expr(&expr, &extension_codec).unwrap(); - let deserialize = parse_expr(&proto, &ctx, &extension_codec).unwrap(); + let deserialize = + parse_expr(&proto, ctx.task_ctx().as_ref(), &extension_codec).unwrap(); let serialize_name = extract_function_name(&expr); let deserialize_name = extract_function_name(&deserialize); diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index f584e3c342271..3fea8df260f05 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -48,11 +48,9 @@ NULL bar 3 baz query I rowsort -SELECT ( - SELECT id FROM table_with_metadata - ) UNION ( - SELECT id FROM table_with_metadata - ); +SELECT id FROM table_with_metadata +UNION +SELECT id FROM table_with_metadata; ---- 1 3 diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 368d252e25006..e70f24303e191 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -688,8 +688,10 @@ query TT explain SELECT t1_id, (SELECT t2_id FROM t2 limit 0) FROM t1 ---- logical_plan -01)Projection: t1.t1_id, Int32(NULL) AS t2_id -02)--TableScan: t1 projection=[t1_id] +01)Projection: t1.t1_id, () +02)--Subquery: +03)----EmptyRelation: rows=0 +04)--TableScan: t1 projection=[t1_id] query II rowsort SELECT t1_id, (SELECT t2_id FROM t2 limit 0) FROM t1 @@ -720,26 +722,79 @@ query TT explain select (select count(*) from t1) as b ---- logical_plan -01)Projection: __scalar_sq_1.count(*) AS b -02)--SubqueryAlias: __scalar_sq_1 +01)Projection: () AS b +02)--Subquery: 03)----Projection: count(Int64(1)) AS count(*) 04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 05)--------TableScan: t1 projection=[] +06)--EmptyRelation: rows=1 #simple_uncorrelated_scalar_subquery2 query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) -02)--Left Join: -03)----SubqueryAlias: __scalar_sq_1 -04)------Projection: count(Int64(1)) AS count(*) -05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -06)----------TableScan: t1 projection=[] -07)----SubqueryAlias: __scalar_sq_2 -08)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -09)--------TableScan: t2 projection=[] +01)Projection: () AS b, () +02)--Subquery: +03)----Projection: count(Int64(1)) AS count(*) +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +05)--------TableScan: t1 projection=[] +06)--Subquery: +07)----Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +08)------TableScan: t2 projection=[] +09)--EmptyRelation: rows=1 + +# Verify projection pushdown works inside uncorrelated scalar subqueries. 
+# Each test targets a different early-return path in OptimizeProjections +# to ensure subquery plans are optimized regardless of where the subquery +# expression appears. + +# Subquery in a Filter predicate: the TableScan inside the subquery should +# only read t2_id, not all columns. +query TT +explain select t1_id from t1 where t1_id > (select max(t2_id) from t2) +---- +logical_plan +01)Filter: t1.t1_id > () +02)--Subquery: +03)----Aggregate: groupBy=[[]], aggr=[[max(t2.t2_id)]] +04)------TableScan: t2 projection=[t2_id] +05)--TableScan: t1 projection=[t1_id] + +# Subquery in a Projection expression +query TT +explain select t1_id, (select max(t2_id) from t2) as max_t2 from t1 +---- +logical_plan +01)Projection: t1.t1_id, () AS max_t2 +02)--Subquery: +03)----Aggregate: groupBy=[[]], aggr=[[max(t2.t2_id)]] +04)------TableScan: t2 projection=[t2_id] +05)--TableScan: t1 projection=[t1_id] + +# Subquery in an Aggregate expression +query TT +explain select sum(t1_int + (select min(t2_int) from t2)) as s from t1 +---- +logical_plan +01)Projection: sum(t1.t1_int + min(t2.t2_int)) AS s +02)--Aggregate: groupBy=[[]], aggr=[[sum(CAST(t1.t1_int + () AS Int64))]] +03)----Subquery: +04)------Aggregate: groupBy=[[]], aggr=[[min(t2.t2_int)]] +05)--------TableScan: t2 projection=[t2_int] +06)----TableScan: t1 projection=[t1_int] + +# Subquery in a Window expression +query TT +explain select t1_id, sum(t1_int + (select min(t2_int) from t2)) over () as win from t1 +---- +logical_plan +01)Projection: t1.t1_id, sum(t1.t1_int + min(t2.t2_int)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS win +02)--WindowAggr: windowExpr=[[sum(CAST(t1.t1_int + () AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +03)----Subquery: +04)------Aggregate: groupBy=[[]], aggr=[[min(t2.t2_int)]] +05)--------TableScan: t2 projection=[t2_int] +06)----TableScan: t1 projection=[t1_id, t1_int] statement ok set datafusion.explain.logical_plan_only = false; @@ -748,22 +803,23 @@ query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) -02)--Left Join: -03)----SubqueryAlias: __scalar_sq_1 -04)------Projection: count(Int64(1)) AS count(*) -05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -06)----------TableScan: t1 projection=[] -07)----SubqueryAlias: __scalar_sq_2 -08)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -09)--------TableScan: t2 projection=[] +01)Projection: () AS b, () +02)--Subquery: +03)----Projection: count(Int64(1)) AS count(*) +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +05)--------TableScan: t1 projection=[] +06)--Subquery: +07)----Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +08)------TableScan: t2 projection=[] +09)--EmptyRelation: rows=1 physical_plan -01)ProjectionExec: expr=[count(*)@0 as b, count(Int64(1))@1 as count(Int64(1))] -02)--NestedLoopJoinExec: join_type=Left -03)----ProjectionExec: expr=[4 as count(*)] -04)------PlaceholderRowExec -05)----ProjectionExec: expr=[4 as count(Int64(1))] -06)------PlaceholderRowExec +01)ScalarSubqueryExec: subqueries=2 +02)--ProjectionExec: expr=[scalar_subquery() as b, scalar_subquery() as count(Int64(1))] +03)----PlaceholderRowExec +04)--ProjectionExec: expr=[4 as count(*)] +05)----PlaceholderRowExec +06)--ProjectionExec: expr=[4 as count(Int64(1))] +07)----PlaceholderRowExec statement ok set datafusion.explain.logical_plan_only = true; @@ -1669,6 +1725,316 @@ drop 
table employees; statement count 0 drop table project_assignments; +############# +## Uncorrelated scalar subquery row-count semantics +## A scalar subquery must return at most one row; returning more is an error. +############# + +statement ok +CREATE TABLE sq_values(v INT) AS VALUES (1), (2), (3); + +statement ok +CREATE TABLE sq_main(x INT) AS VALUES (10), (20); + +statement ok +CREATE TABLE sq_empty(v INT) AS VALUES (1) LIMIT 0; + +# Scalar subquery returning multiple rows in SELECT position → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT (SELECT v FROM sq_values); + +# Scalar subquery returning multiple rows in WHERE position → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT x FROM sq_main WHERE x > (SELECT v FROM sq_values); + +# Scalar subquery returning multiple rows as a function argument → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT x + (SELECT v FROM sq_values) FROM sq_main; + +# Scalar subquery returning exactly one row → success +query I +SELECT (SELECT v FROM sq_values LIMIT 1); +---- +1 + +# Scalar subquery returning exactly one row in WHERE → success +query I rowsort +SELECT x FROM sq_main WHERE x > (SELECT v FROM sq_values LIMIT 1); +---- +10 +20 + +# Scalar subquery returning zero rows → NULL +query I +SELECT (SELECT v FROM sq_empty); +---- +NULL + +# Scalar subquery returning zero rows in arithmetic → NULL propagation +query I +SELECT x + (SELECT v FROM sq_empty) FROM sq_main; +---- +NULL +NULL + +# Scalar subquery returning zero rows in WHERE comparison → no matching rows +query I +SELECT x FROM sq_main WHERE x > (SELECT v FROM sq_empty); +---- + +# Aggregated subquery always returns one row, even on empty input → success +query I +SELECT (SELECT count(*) FROM sq_empty); +---- +0 + +# Aggregated subquery on multi-row table → success (aggregation reduces to 1 row) +query I +SELECT (SELECT max(v) FROM sq_values); +---- +3 + +# Multiple scalar subqueries, one returns multiple rows → error +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT (SELECT count(*) FROM sq_empty), (SELECT v FROM sq_values); + +############# +## Uncorrelated scalar subqueries in various expression contexts +############# + +# HAVING clause with uncorrelated scalar subquery +query II rowsort +SELECT x, count(*) AS cnt FROM sq_main GROUP BY x +HAVING count(*) > (SELECT min(v) FROM sq_values); +---- + +# CASE WHEN with uncorrelated scalar subquery as condition +query T rowsort +SELECT CASE WHEN x > (SELECT min(v) FROM sq_values) + THEN 'big' ELSE 'small' END AS label +FROM sq_main; +---- +big +big + +# ORDER BY with uncorrelated scalar subquery +query I +SELECT x FROM sq_main ORDER BY x + (SELECT max(v) FROM sq_values); +---- +10 +20 + +# Aggregate function argument containing uncorrelated scalar subquery +query I +SELECT sum(x + (SELECT max(v) FROM sq_values)) AS s FROM sq_main; +---- +36 + +# JOIN ON condition with uncorrelated scalar subquery +query II rowsort +SELECT l.x, r.x AS rx +FROM sq_main AS l JOIN sq_main AS r +ON l.x = r.x + (SELECT min(v) FROM sq_values); +---- + +# Nested uncorrelated-in-uncorrelated scalar subquery. +query I +SELECT (SELECT max(v) + (SELECT min(v) FROM sq_values) FROM sq_values); +---- +4 + +# Verify nested subqueries are not hoisted: the root ScalarSubqueryExec +# should manage only the outer subquery (subqueries=1), not both. 
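+# (subqueries=N counts only the subquery plans attached directly to that +# ScalarSubqueryExec node; a nested subquery appears as its own +# ScalarSubqueryExec deeper in the tree, as the physical plan below shows.)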
+query TT +EXPLAIN SELECT (SELECT max(v) + (SELECT min(v) FROM sq_values) FROM sq_values); +---- +logical_plan +01)Projection: () +02)--Subquery: +03)----Projection: max(sq_values.v) + () +04)------Subquery: +05)--------Aggregate: groupBy=[[]], aggr=[[min(sq_values.v)]] +06)----------TableScan: sq_values projection=[v] +07)------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] +08)--------TableScan: sq_values projection=[v] +09)--EmptyRelation: rows=1 +physical_plan +01)ScalarSubqueryExec: subqueries=1 +02)--ProjectionExec: expr=[scalar_subquery() as max(sq_values.v) + min(sq_values.v)] +03)----PlaceholderRowExec +04)--ScalarSubqueryExec: subqueries=1 +05)----ProjectionExec: expr=[max(sq_values.v)@0 + scalar_subquery() as max(sq_values.v) + min(sq_values.v)] +06)------AggregateExec: mode=Single, gby=[], aggr=[max(sq_values.v)] +07)--------DataSourceExec: partitions=1, partition_sizes=[1] +08)----AggregateExec: mode=Single, gby=[], aggr=[min(sq_values.v)] +09)------DataSourceExec: partitions=1, partition_sizes=[1] + +# CTE as source inside uncorrelated scalar subquery +query I +SELECT (SELECT s FROM (WITH cte AS (SELECT max(v) AS s FROM sq_values) SELECT s FROM cte)); +---- +3 + +# Window function with uncorrelated scalar subquery +query II rowsort +SELECT x, sum(x + (SELECT max(v) FROM sq_values)) OVER () AS win_sum FROM sq_main; +---- +10 36 +20 36 + +# Duplicate uncorrelated scalar subqueries only appear in the query plan once +statement ok +set datafusion.explain.logical_plan_only = false; + +query TT +explain SELECT (SELECT max(v) FROM sq_values) + (SELECT max(v) FROM sq_values) AS doubled; +---- +logical_plan +01)Projection: __common_expr_1 + __common_expr_1 AS doubled +02)--Projection: () AS __common_expr_1 +03)----Subquery: +04)------Aggregate: groupBy=[[]], aggr=[[max(sq_values.v)]] +05)--------TableScan: sq_values projection=[v] +06)----EmptyRelation: rows=1 +physical_plan +01)ScalarSubqueryExec: subqueries=1 +02)--ProjectionExec: expr=[__common_expr_1@0 + __common_expr_1@0 as doubled] +03)----ProjectionExec: expr=[scalar_subquery() as __common_expr_1] +04)------PlaceholderRowExec +05)--AggregateExec: mode=Single, gby=[], aggr=[max(sq_values.v)] +06)----DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +RESET datafusion.explain.logical_plan_only; + +############# +## Additional edge cases inspired by DuckDB scalar subquery bugs and tests. +## References: +## DuckDB issue #13639: volatile functions in uncorrelated subqueries +## DuckDB issue #13469 / PR #13514: multi-row error semantics +## DuckDB issue #4113: uncorrelated scalar subquery wrong results +## DuckDB test: test/sql/subquery/scalar/test_uncorrelated_scalar_subquery.test +## DuckDB test: test/sql/subquery/scalar/test_scalar_subquery.test +## DuckDB test: test/sql/subquery/scalar/test_scalar_subquery_cte.test +## DuckDB test: test/sql/order/test_limit.test +## DuckDB PR #8519: TopN optimization with scalar subquery in LIMIT/OFFSET +############# + +# Volatile function in uncorrelated subquery: random() is evaluated once and +# the same value is used for every outer row. Both Postgres and DuckDB (#13639) +# exhibit this behavior; it is correct per the SQL standard. +# We cannot assert the exact value, but we can verify all rows are identical. +query B +SELECT count(DISTINCT r) = 1 FROM ( + SELECT (SELECT random()) AS r + FROM (VALUES (1), (2), (3)) AS t(x) +); +---- +true + +# Subquery as GROUP BY key. 
+# Ref: DuckDB test_uncorrelated_scalar_subquery.test, test #6 +query II +SELECT (SELECT 42) AS k, max(x) FROM (VALUES (1), (2), (3)) AS t(x) GROUP BY k; +---- +42 3 + +# Subquery inside an aggregate function argument. +# Ref: DuckDB test_uncorrelated_scalar_subquery.test, test #7 +query II +SELECT x, max((SELECT 42)) FROM (VALUES (1), (2), (3)) AS t(x) GROUP BY x ORDER BY x; +---- +1 42 +2 42 +3 42 + +# Doubly-nested constant subquery. +# Ref: DuckDB test_scalar_subquery.test +query I +SELECT (SELECT (SELECT 42)); +---- +42 + +# Triple-nested constant subquery. +query I +SELECT (SELECT (SELECT (SELECT 99))); +---- +99 + +# Star expansion: single column is OK. +# Ref: DuckDB test_uncorrelated_scalar_subquery.test, tests #16-17 +query I +SELECT (SELECT * FROM (VALUES (1)) AS t(x)); +---- +1 + +# Star expansion: two columns must error. +query error Too many columns +SELECT (SELECT * FROM (VALUES (1, 2)) AS t(x, y)); + +# Subquery in BETWEEN bounds. +query I +SELECT x FROM (VALUES (1), (2), (3), (4), (5)) AS t(x) +WHERE x BETWEEN (SELECT 2) AND (SELECT 4) +ORDER BY x; +---- +2 +3 +4 + +# DISTINCT subquery returning exactly one distinct value (multi-row input). +query I +SELECT (SELECT DISTINCT 42 FROM (VALUES (1), (2), (3)) AS t(x)); +---- +42 + +# DISTINCT subquery returning multiple distinct values must error. +query error DataFusion error: Execution error: Scalar subquery returned more than one row +SELECT (SELECT DISTINCT x FROM (VALUES (1), (2), (3)) AS t(x)); + +# NULL comparison semantics through subquery boundary. +# Ref: DuckDB test_scalar_subquery.test, NULL edge cases +query B +SELECT 1 = (SELECT CAST(NULL AS INT)); +---- +NULL + +query B +SELECT CAST(NULL AS INT) = (SELECT 1); +---- +NULL + +# Nested CTE inside scalar subquery. +# Ref: DuckDB test_scalar_subquery_cte.test, nested CTE test +query I +SELECT (WITH cte1 AS (WITH cte2 AS (SELECT 42) SELECT * FROM cte2) SELECT * FROM cte1); +---- +42 + +# Subquery in LIMIT: not yet supported, verify clear error message. +# Ref: DuckDB PR #8519, DuckDB test/sql/order/test_limit.test +query error This feature is not implemented: Unsupported LIMIT expression +SELECT * FROM (VALUES (1), (2), (3)) AS t(x) ORDER BY x LIMIT (SELECT 2); + +# Subquery in OFFSET: not yet supported, verify clear error message. +query error This feature is not implemented: Unsupported OFFSET expression +SELECT * FROM (VALUES (1), (2), (3)) AS t(x) ORDER BY x OFFSET (SELECT 1); + +# UNION ALL subquery with ORDER BY + LIMIT 1 to avoid multi-row error. +query I +SELECT (SELECT v FROM (SELECT 1 AS v UNION ALL SELECT 2) AS t ORDER BY v LIMIT 1); +---- +1 + +statement count 0 +DROP TABLE sq_values; + +statement count 0 +DROP TABLE sq_main; + +statement count 0 +DROP TABLE sq_empty; + # https://github.com/apache/datafusion/issues/21205 statement ok CREATE TABLE dup_filter_t1(id INTEGER) AS VALUES (1), (2), (3); @@ -1736,3 +2102,38 @@ DROP TABLE sq_name_t1; statement ok DROP TABLE sq_name_t2; + +# Test: scalar subquery in filter on a partition column of a partitioned table. +# This exercises the code path where filters are pushed down to the table +# provider for partition pruning. Scalar subqueries must not be pushed to the +# provider because the subquery result is not available at partition listing +# time. 
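+# The query below must still return the part=1 rows even though the provider +# cannot see the subquery filter when listing partitions.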
+ +query I +COPY (VALUES(1, 'a'), (2, 'b'), (3, 'c')) +TO 'test_files/scratch/subquery/partition_pruning/part=1/file1.parquet'; +---- +3 + +query I +COPY (VALUES(4, 'd'), (5, 'e')) +TO 'test_files/scratch/subquery/partition_pruning/part=2/file1.parquet'; +---- +2 + +statement ok +CREATE EXTERNAL TABLE subquery_partitioned +STORED AS PARQUET +LOCATION 'test_files/scratch/subquery/partition_pruning/'; + +query IT +SELECT column1, column2 FROM subquery_partitioned +WHERE part = (SELECT 1) +ORDER BY column1; +---- +1 a +2 b +3 c + +statement ok +DROP TABLE subquery_partitioned; diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part index a31579eb1e09d..0c5b6d76dc1e1 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part @@ -49,61 +49,62 @@ limit 10; logical_plan 01)Sort: value DESC NULLS FIRST, fetch=10 02)--Projection: partsupp.ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS value -03)----Inner Join: Filter: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > __scalar_sq_1.sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) -04)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -05)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost -06)----------Inner Join: supplier.s_nationkey = nation.n_nationkey -07)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -08)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], partial_filters=[Boolean(true)] -10)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] -11)------------Projection: nation.n_nationkey -12)--------------Filter: nation.n_name = Utf8View("GERMANY") -13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] -14)------SubqueryAlias: __scalar_sq_1 -15)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) -16)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -17)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost -18)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey -19)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -20)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -21)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] -22)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] -23)----------------Projection: nation.n_nationkey -24)------------------Filter: nation.n_name = Utf8View("GERMANY") -25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] +03)----Filter: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > () +04)------Subquery: +05)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) +06)----------Aggregate: groupBy=[[]], 
aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +07)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost +08)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey +09)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +10)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +11)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] +12)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] +13)----------------Projection: nation.n_nationkey +14)------------------Filter: nation.n_name = Utf8View("GERMANY") +15)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] +16)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +17)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost +18)----------Inner Join: supplier.s_nationkey = nation.n_nationkey +19)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +20)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +21)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] +22)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] +23)------------Projection: nation.n_nationkey +24)--------------Filter: nation.n_name = Utf8View("GERMANY") +25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] physical_plan -01)SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[false] -02)--ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] -03)----NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@1 > sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@0, projection=[ps_partkey@0, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1, sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@3] -04)------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as sum(partsupp.ps_supplycost * partsupp.ps_availqty), CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 AS Decimal128(38, 15)) as join_proj_push_down_1] -05)--------CoalescePartitionsExec -06)----------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -07)------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2] -10)------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 -11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5] -12)----------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 -13)------------------------DataSourceExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false -14)----------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -15)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false -16)------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -17)--------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] -18)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -19)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false -20)------ProjectionExec: expr=[CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)] -21)--------AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -22)----------CoalescePartitionsExec -23)------------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] -24)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1] -25)----------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 -26)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4] -27)--------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 -28)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false -29)--------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -30)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false -31)----------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -32)------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] -33)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -34)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false 
+01)ScalarSubqueryExec: subqueries=1 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +03)----SortPreservingMergeExec: [value@1 DESC], fetch=10 +04)------SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[true] +05)--------ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] +06)----------FilterExec: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 AS Decimal128(38, 15)) > scalar_subquery() +07)------------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +08)--------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +09)----------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2] +11)--------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 +12)----------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5] +13)------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 +14)--------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false +15)------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 +16)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false +17)--------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +18)----------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] +19)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +20)--------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false +21)--ProjectionExec: expr=[CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)] +22)----AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +23)------CoalescePartitionsExec +24)--------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +25)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1] +26)------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 +27)--------------HashJoinExec: mode=Partitioned, 
join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4] +28)----------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 +29)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], file_type=csv, has_header=false +30)----------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 +31)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], file_type=csv, has_header=false +32)------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +33)--------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] +34)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +35)------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part index ae0c0a93a3552..3e1aca318b5c7 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q15.slt.part @@ -52,43 +52,44 @@ order by logical_plan 01)Sort: supplier.s_suppkey ASC NULLS LAST 02)--Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue -03)----Inner Join: revenue0.total_revenue = __scalar_sq_1.max(revenue0.total_revenue) -04)------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue -05)--------Inner Join: supplier.s_suppkey = revenue0.supplier_no -06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone], partial_filters=[Boolean(true)] -07)----------SubqueryAlias: revenue0 -08)------------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -09)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -10)----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount -11)------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") -12)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] -13)------SubqueryAlias: __scalar_sq_1 -14)--------Aggregate: groupBy=[[]], aggr=[[max(revenue0.total_revenue)]] -15)----------SubqueryAlias: revenue0 -16)------------Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -17)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], 
aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
-18)----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount
-19)------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01")
-20)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")]
+03)----Inner Join: supplier.s_suppkey = revenue0.supplier_no
+04)------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone]
+05)------SubqueryAlias: revenue0
+06)--------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue
+07)----------Filter: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) = (<subquery>)
+08)------------Subquery:
+09)--------------Aggregate: groupBy=[[]], aggr=[[max(revenue0.total_revenue)]]
+10)----------------SubqueryAlias: revenue0
+11)------------------Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue
+12)--------------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
+13)----------------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount
+14)------------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01")
+15)--------------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")]
+16)------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
+17)--------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount
+18)----------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01")
+19)------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")]
 physical_plan
-01)SortPreservingMergeExec: [s_suppkey@0 ASC NULLS LAST]
-02)--SortExec: expr=[s_suppkey@0 ASC NULLS LAST], preserve_partitioning=[true]
-03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(max(revenue0.total_revenue)@0, total_revenue@4)], projection=[s_suppkey@1, s_name@2, s_address@3, s_phone@4, total_revenue@5]
-04)------AggregateExec: mode=Final, gby=[], aggr=[max(revenue0.total_revenue)]
-05)--------CoalescePartitionsExec
-06)----------AggregateExec: mode=Partial, gby=[], aggr=[max(revenue0.total_revenue)]
-07)------------ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue]
-08)--------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
-09)----------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4
-10)------------------AggregateExec: mode=Partial, 
gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -11)--------------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] -12)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false -13)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_suppkey@0, supplier_no@0)], projection=[s_suppkey@0, s_name@1, s_address@2, s_phone@3, total_revenue@5] -14)--------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 -15)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_phone], file_type=csv, has_header=false -16)--------ProjectionExec: expr=[l_suppkey@0 as supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] -17)----------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -18)------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 -19)--------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -20)----------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] -21)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false +01)ScalarSubqueryExec: subqueries=1 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +03)----SortPreservingMergeExec: [s_suppkey@0 ASC NULLS LAST] +04)------SortExec: expr=[s_suppkey@0 ASC NULLS LAST], preserve_partitioning=[true] +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_suppkey@0, supplier_no@0)], projection=[s_suppkey@0, s_name@1, s_address@2, s_phone@3, total_revenue@5] +06)----------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=1 +07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_phone], file_type=csv, has_header=false +08)----------ProjectionExec: expr=[l_suppkey@0 as supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +09)------------FilterExec: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 = scalar_subquery() +10)--------------AggregateExec: 
mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +11)----------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 +12)------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +13)--------------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] +14)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false +15)--AggregateExec: mode=Final, gby=[], aggr=[max(revenue0.total_revenue)] +16)----CoalescePartitionsExec +17)------AggregateExec: mode=Partial, gby=[], aggr=[max(revenue0.total_revenue)] +18)--------ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +19)----------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +20)------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 +21)--------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +22)----------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] +23)------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index add578c3b079d..3240cbfb697d5 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -61,40 +61,36 @@ logical_plan 03)----Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[count(Int64(1)), sum(custsale.c_acctbal)]] 04)------SubqueryAlias: custsale 05)--------Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal -06)----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.avg(customer.c_acctbal) -07)------------Projection: customer.c_phone, customer.c_acctbal -08)--------------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey -09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) -10)------------------TableScan: customer projection=[c_custkey, c_phone, 
c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]), Boolean(true)]
-11)----------------SubqueryAlias: __correlated_sq_1
-12)------------------TableScan: orders projection=[o_custkey]
-13)------------SubqueryAlias: __scalar_sq_2
-14)--------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]]
-15)----------------Projection: customer.c_acctbal
-16)------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])
-17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])]
+06)----------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey
+07)------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")]) AND CAST(customer.c_acctbal AS Decimal128(19, 6)) > (<subquery>)
+08)--------------Subquery:
+09)----------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]]
+10)------------------Projection: customer.c_acctbal
+11)--------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])
+12)----------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])]
+13)--------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])]
+14)------------SubqueryAlias: __correlated_sq_1
+15)--------------TableScan: orders projection=[o_custkey]
 physical_plan
-01)SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST]
-02)--SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true]
-03)----ProjectionExec: expr=[cntrycode@0 as cntrycode, count(Int64(1))@1 as numcust, sum(custsale.c_acctbal)@2 as totacctbal]
-04)------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)]
-05)--------RepartitionExec: partitioning=Hash([cntrycode@0], 4), input_partitions=4
-06)----------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)]
-07)------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal]
-08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-09)----------------NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@1 > avg(customer.c_acctbal)@0, projection=[c_phone@0, c_acctbal@1, avg(customer.c_acctbal)@3]
-10)------------------ProjectionExec: expr=[c_phone@0 as c_phone, c_acctbal@1 as c_acctbal, CAST(c_acctbal@1 AS 
Decimal128(19, 6)) as join_proj_push_down_1] -11)--------------------CoalescePartitionsExec -12)----------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2] -13)------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 -14)--------------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) -15)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false -17)------------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4 -18)--------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_custkey], file_type=csv, has_header=false -19)------------------AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] -20)--------------------CoalescePartitionsExec -21)----------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] -22)------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1] -23)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -24)----------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false +01)ScalarSubqueryExec: subqueries=1 +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true +03)----SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] +04)------SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] +05)--------ProjectionExec: expr=[cntrycode@0 as cntrycode, count(Int64(1))@1 as numcust, sum(custsale.c_acctbal)@2 as totacctbal] +06)----------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] +07)------------RepartitionExec: partitioning=Hash([cntrycode@0], 4), input_partitions=4 +08)--------------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[count(Int64(1)), sum(custsale.c_acctbal)] +09)----------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal] +10)------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2] +11)--------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 +12)----------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]) AND CAST(c_acctbal@2 AS Decimal128(19, 6)) > scalar_subquery() +13)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)--------------------------DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false
+15)--------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4
+16)----------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_custkey], file_type=csv, has_header=false
+17)--AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)]
+18)----CoalescePartitionsExec
+19)------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)]
+20)--------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1]
+21)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+22)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false
diff --git a/docs/source/library-user-guide/upgrading/54.0.0.md b/docs/source/library-user-guide/upgrading/54.0.0.md
index cadd365e1f814..315418dec5cf2 100644
--- a/docs/source/library-user-guide/upgrading/54.0.0.md
+++ b/docs/source/library-user-guide/upgrading/54.0.0.md
@@ -320,6 +320,92 @@ The difference is only observable for strings containing combining characters
 clusters (e.g., ZWJ emoji sequences). For ASCII and most common Unicode text,
 behavior is unchanged.
+
+### Scalar subquery execution changes
+
+Uncorrelated scalar subqueries (e.g. `SELECT ... WHERE x > (SELECT max(v) FROM t)`)
+are now executed by a dedicated physical operator rather than being rewritten
+to a join. Correlated scalar subqueries are unchanged.
+
+This produces two user-visible changes:
+
+- **Subqueries that return multiple rows now fail at runtime.** An uncorrelated
+  scalar subquery that returns more than one row now fails with
+  `Execution error: Scalar subquery returned more than one row`. This matches
+  the SQL standard and the behavior of most other SQL implementations; the
+  previous join-based rewrite could silently accept multi-row subquery results.
+  Add a `LIMIT 1` or an aggregate to the subquery to fix such queries.
+- **Plan shape changes.** Uncorrelated `Expr::ScalarSubquery` nodes now survive
+  into the final logical plan instead of being replaced by a join, and the
+  corresponding physical plan contains a new `ScalarSubqueryExec` node and a
+  `ScalarSubqueryExpr` expression. Code that walks or transforms `LogicalPlan`
+  or `ExecutionPlan` trees may need updating, and `EXPLAIN` output changes
+  accordingly.
+
+### `datafusion-proto`: expression deserialization now takes a `TaskContext`
+
+`Serializeable::from_bytes_with_registry` is renamed to `from_bytes_with_ctx`
+and takes a `&TaskContext` instead of a `&dyn FunctionRegistry`. `parse_expr`,
+`parse_exprs`, and `parse_sorts` change in the same way. `Expr::from_bytes`
+(without a registry argument) is unchanged.
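+
+For illustration, a minimal round-trip with the new API might look like the
+sketch below. It assumes a default `SessionContext` and the long-standing
+`Serializeable::to_bytes`; only `from_bytes_with_ctx` is new, and the
+`roundtrip_expr` helper is purely illustrative:
+
+```rust
+use datafusion::error::Result;
+use datafusion::prelude::{col, lit, Expr, SessionContext};
+use datafusion_proto::bytes::Serializeable;
+
+// Illustrative helper: serialize an expression to protobuf bytes, then
+// decode it back using the TaskContext-based API introduced in this release.
+fn roundtrip_expr() -> Result<()> {
+    let ctx = SessionContext::new();
+    let expr = col("a").gt(lit(5));
+    let bytes = expr.to_bytes()?;
+    let decoded = Expr::from_bytes_with_ctx(&bytes, ctx.task_ctx().as_ref())?;
+    assert_eq!(expr, decoded);
+    Ok(())
+}
+```
+
+In existing call sites the change is mechanical: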
+
+```diff
+-let expr = Expr::from_bytes_with_registry(&bytes, &registry)?;
++let expr = Expr::from_bytes_with_ctx(&bytes, ctx.task_ctx().as_ref())?;
+```
+
+```diff
+-let expr = parse_expr(&proto, &registry, &codec)?;
++let expr = parse_expr(&proto, ctx.task_ctx().as_ref(), &codec)?;
+```
+
+### `datafusion-proto`: `PhysicalProtoConverterExtension` reshaped
+
+`PhysicalProtoConverterExtension` and the `parse_physical_*_with_converter`
+helpers now take a single `&PhysicalPlanDecodeContext<'_>` that bundles the
+`TaskContext` and the `PhysicalExtensionCodec`. Implementations update like
+this:
+
+```diff
+ impl PhysicalProtoConverterExtension for MyConverter {
+     fn proto_to_execution_plan(
+         &self,
+-        ctx: &TaskContext,
+-        codec: &dyn PhysicalExtensionCodec,
+         proto: &protobuf::PhysicalPlanNode,
++        ctx: &PhysicalPlanDecodeContext<'_>,
+     ) -> Result<Arc<dyn ExecutionPlan>> {
+-        proto.try_into_physical_plan_with_converter(ctx, codec, self)
++        self.default_proto_to_execution_plan(proto, ctx)
+     }
+
+     fn proto_to_physical_expr(
+         &self,
+         proto: &PhysicalExprNode,
+-        ctx: &TaskContext,
+         input_schema: &Schema,
+-        codec: &dyn PhysicalExtensionCodec,
++        ctx: &PhysicalPlanDecodeContext<'_>,
+     ) -> Result<Arc<dyn PhysicalExpr>> {
+-        parse_physical_expr_with_converter(proto, ctx, input_schema, codec, self)
++        self.default_proto_to_physical_expr(proto, input_schema, ctx)
+     }
+ }
+```
+
+Pull out the `TaskContext` or codec inside these methods with
+`ctx.task_ctx()` and `ctx.codec()`. Construct a fresh context at an API
+boundary with `PhysicalPlanDecodeContext::new(task_ctx, codec)`.
+
+### `ExecutionProps` has new fields
+
+`ExecutionProps` gained new public fields. Code that constructs it via a
+struct literal, or pattern-matches it without `..`, no longer compiles. Use
+`ExecutionProps::new()` and include `..` in exhaustive patterns.
+
+### Wire format: scalar subquery serialization
+
+`datafusion-proto` adds new oneof variants to serialize scalar subqueries.
+Plans produced by DataFusion 54 that contain scalar subqueries cannot be
+decoded by older versions; upgrade producers and consumers together.
+
 ### Items in `datafusion_functions::strings` are no longer public

`StringArrayBuilder`, `LargeStringArrayBuilder`, `StringViewArrayBuilder`,