Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2482,6 +2482,7 @@ impl DataFrame {
relation: None,
name: field.name().to_string(),
metadata: None,
is_internal: false,
}),
Err(_) => col(field.name()),
}
Expand Down
264 changes: 245 additions & 19 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2383,10 +2383,10 @@ pub fn create_window_expr(
) -> Result<Arc<dyn WindowExpr>> {
// unpack aliased logical expressions, e.g. "sum(col) over () as total"
let (name, e) = match e {
Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()),
_ => (e.schema_name().to_string(), e),
Expr::Alias(alias) => (alias.name.clone(), e.clone().unalias_nested().data),
_ => (e.schema_name().to_string(), e.clone()),
};
create_window_expr_with_name(e, name, logical_schema, execution_props)
create_window_expr_with_name(&e, name, logical_schema, execution_props)
}

type AggregateExprWithOptionalArgs = (
Expand All @@ -2401,7 +2401,7 @@ type AggregateExprWithOptionalArgs = (
pub fn create_aggregate_expr_with_name_and_maybe_filter(
e: &Expr,
name: Option<String>,
human_displan: String,
human_display: Option<String>,
logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
Expand Down Expand Up @@ -2445,16 +2445,17 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
execution_props,
)?;

let agg_expr =
let mut builder =
AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec())
.order_by(order_bys.clone())
.schema(Arc::new(physical_input_schema.to_owned()))
.alias(name)
.human_display(human_displan)
.with_ignore_nulls(ignore_nulls)
.with_distinct(*distinct)
.build()
.map(Arc::new)?;
.with_distinct(*distinct);
if let Some(human_display) = human_display {
builder = builder.human_display(human_display);
}
let agg_expr = builder.build().map(Arc::new)?;

(agg_expr, filter, order_bys)
};
Expand All @@ -2472,20 +2473,28 @@ pub fn create_aggregate_expr_and_maybe_filter(
physical_input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<AggregateExprWithOptionalArgs> {
// Unpack (potentially nested) aliased logical expressions, e.g. "sum(col) as total"
// Some functions like `count_all()` create internal aliases,
// Unwrap all alias layers to get to the underlying aggregate function
// Unpack (potentially nested) aliased logical expressions, e.g. "sum(col) as total".
// Physical explain prefers the lowered aggregate form, so unwrap all alias
// layers to recover the underlying aggregate function and then re-attach
// only the visible output alias.
let (name, human_display, e) = match e {
Expr::Alias(Alias { name, .. }) => {
Expr::Alias(alias) => {
let unaliased = e.clone().unalias_nested().data;
(Some(name.clone()), e.human_display().to_string(), unaliased)
let human_display = unaliased.human_display().to_string();
let human_display = if human_display.is_empty() || human_display == alias.name
{
alias.name.clone()
} else {
format!("{human_display} as {}", alias.name)
};
(Some(alias.name.clone()), Some(human_display), unaliased)
}
Expr::AggregateFunction(_) => (
Some(e.schema_name().to_string()),
e.human_display().to_string(),
Some(e.human_display().to_string()),
e.clone(),
),
_ => (None, String::default(), e.clone()),
_ => (None, None, e.clone()),
};

create_aggregate_expr_with_name_and_maybe_filter(
Expand Down Expand Up @@ -3106,6 +3115,7 @@ impl<'n> TreeNodeVisitor<'n> for InvariantChecker {
mod tests {
use std::cmp::Ordering;
use std::fmt::{self, Debug};
use std::mem::size_of_val;
use std::ops::{BitAnd, Not};

use super::*;
Expand All @@ -3121,19 +3131,22 @@ mod tests {
use crate::execution::session_state::SessionStateBuilder;
use arrow::array::{ArrayRef, DictionaryArray, Int32Array};
use arrow::datatypes::{DataType, Field, Int32Type};
use arrow_schema::SchemaRef;
use arrow_schema::{FieldRef, SchemaRef};
use datafusion_common::config::ConfigOptions;
use datafusion_common::{
DFSchemaRef, TableReference, ToDFSchema as _, assert_contains,
DFSchemaRef, ScalarValue, TableReference, ToDFSchema as _, assert_contains,
};
use datafusion_execution::TaskContext;
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::builder::subquery_alias;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::{
LogicalPlanBuilder, TableSource, UserDefinedLogicalNodeCore, col, lit,
Accumulator, AggregateUDF, AggregateUDFImpl, ExprFunctionExt, LogicalPlanBuilder,
Signature, TableSource, UserDefinedLogicalNodeCore, Volatility, col, lit,
};
use datafusion_functions_aggregate::count::count_all;
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_functions_aggregate::first_last::first_value_udaf;
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};

Expand All @@ -3158,6 +3171,83 @@ mod tests {
.await
}

/// Plans `logical_plan` into a physical plan and renders it as the
/// indented explain string used by the aggregate display tests below.
async fn aggregate_explain(logical_plan: &LogicalPlan) -> Result<String> {
    let physical = plan(logical_plan).await?;
    let rendered = displayable(physical.as_ref()).indent(true).to_string();
    Ok(rendered)
}

/// Accumulator stub for the test UDAF below: it ignores every input batch
/// and always evaluates to a null `Int64`, so tests only exercise planning
/// and explain output, never real aggregation.
#[derive(Debug, Default)]
struct NullAccumulator;

impl Accumulator for NullAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        // Single state field mirroring the (always-null) final value.
        Ok(vec![ScalarValue::Int64(None)])
    }

    fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> {
        // Input is deliberately ignored; see the struct-level comment.
        Ok(())
    }

    fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> {
        // Nothing to merge for a stateless stub.
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(None))
    }

    fn size(&self) -> usize {
        size_of_val(self)
    }
}

/// Test UDAF whose only purpose is overriding `human_display` so the
/// explain output can be checked against a custom rendering.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CustomHumanDisplayUdaf {
    signature: Signature,
}

impl CustomHumanDisplayUdaf {
    fn new() -> Self {
        // Exactly one Int64 argument; immutable so plans can be cached.
        let signature = Signature::exact(vec![DataType::Int64], Volatility::Immutable);
        Self { signature }
    }
}

impl AggregateUDFImpl for CustomHumanDisplayUdaf {
    fn name(&self) -> &str {
        "custom_human_display_udaf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    fn accumulator(
        &self,
        _acc_args: AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        // The accumulator never produces real values; only explain output
        // matters for these tests.
        Ok(Box::new(NullAccumulator))
    }

    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        let state = Field::new("custom_state", DataType::Int64, true);
        Ok(vec![state.into()])
    }

    fn human_display(&self, params: &AggregateFunctionParams) -> Result<String> {
        // Render as `custom_display(<first arg>)` instead of the default
        // `name(args)` form so tests can tell the override was honored.
        let arg = params.args[0].human_display();
        Ok(format!("custom_display({})", arg))
    }
}

#[tokio::test]
async fn test_all_operators() -> Result<()> {
let logical_plan = test_csv_scan()
Expand Down Expand Up @@ -3769,6 +3859,142 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_aggregate_explain_shows_quoted_user_alias() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new(
"column1",
DataType::Int64,
false,
)]));

let logical_plan = scan_empty(None, schema.as_ref(), None)?
.aggregate(
Vec::<Expr>::new(),
vec![sum(col("column1")).alias("total rows")],
)?
.build()?;

assert_contains!(
aggregate_explain(&logical_plan).await?,
"AggregateExec: mode=Single, gby=[], aggr=[sum(?table?.column1) as total rows]"
);

Ok(())
}

#[tokio::test]
async fn test_aggregate_explain_shows_aliased_filter_expression() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("column1", DataType::Int64, false),
Field::new("column2", DataType::Int64, false),
]));

let logical_plan = scan_empty(None, schema.as_ref(), None)?
.aggregate(
Vec::<Expr>::new(),
vec![
sum(col("column1"))
.filter(col("column2").lt_eq(lit(0_i64)))
.build()?
.alias("agg"),
],
)?
.build()?;

assert_contains!(
aggregate_explain(&logical_plan).await?,
"AggregateExec: mode=Single, gby=[], aggr=[sum(?table?.column1) FILTER (WHERE ?table?.column2 <= Int64(0)) as agg]"
);

Ok(())
}

#[tokio::test]
async fn test_aggregate_explain_shows_aliased_respect_nulls() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("column1", DataType::Int64, true),
Field::new("column2", DataType::Int64, false),
]));

let logical_plan = scan_empty(None, schema.as_ref(), None)?
.aggregate(
Vec::<Expr>::new(),
vec![
first_value_udaf()
.call(vec![col("column1")])
.order_by(vec![col("column2").sort(true, true)])
.null_treatment(NullTreatment::RespectNulls)
.build()?
.alias("agg"),
],
)?
.build()?;

assert_contains!(
aggregate_explain(&logical_plan).await?,
"AggregateExec: mode=Single, gby=[], aggr=[first_value(?table?.column1) RESPECT NULLS ORDER BY [?table?.column2 ASC NULLS FIRST] as agg]"
);

Ok(())
}

#[tokio::test]
async fn test_aggregate_explain_shows_count_all() -> Result<()> {
let logical_plan = test_csv_scan()
.await?
.aggregate(Vec::<Expr>::new(), vec![count_all()])?
.build()?;

assert_contains!(
aggregate_explain(&logical_plan).await?,
"aggr=[count(1) as count(*)]"
);

Ok(())
}

#[tokio::test]
async fn test_aggregate_explain_shows_count_all_with_user_alias() -> Result<()> {
let logical_plan = test_csv_scan()
.await?
.aggregate(Vec::<Expr>::new(), vec![count_all().alias("total_rows")])?
.build()?;

assert_contains!(
aggregate_explain(&logical_plan).await?,
"aggr=[count(1) as total_rows]"
);

Ok(())
}

#[tokio::test]
async fn test_aggregate_explain_shows_aliased_custom_human_display() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new(
"column1",
DataType::Int64,
false,
)]));

let logical_plan = scan_empty(None, schema.as_ref(), None)?
.aggregate(
Vec::<Expr>::new(),
vec![
AggregateUDF::from(CustomHumanDisplayUdaf::new())
.call(vec![col("column1")])
.alias("agg"),
],
)?
.build()?;

assert_contains!(
aggregate_explain(&logical_plan).await?,
"AggregateExec: mode=Single, gby=[], aggr=[custom_display(?table?.column1) as agg]"
);

Ok(())
}

#[tokio::test]
async fn test_explain() {
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
Expand Down
32 changes: 16 additions & 16 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3022,20 +3022,20 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
assert_snapshot!(
pretty_format_batches(&df_results).unwrap(),
@r"
+---------------+----------------------------------------------------------------------------+
| plan_type | plan |
+---------------+----------------------------------------------------------------------------+
| logical_plan | Sort: count(*) AS count(*) ASC NULLS LAST |
| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] |
| | TableScan: t1 projection=[b] |
| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] |
| | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] |
| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] |
| | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 |
| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | |
+---------------+----------------------------------------------------------------------------+
+---------------+---------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+---------------------------------------------------------------------------------------+
| logical_plan | Sort: count(*) AS count(*) ASC NULLS LAST |
| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] |
| | TableScan: t1 projection=[b] |
| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] |
| | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] |
| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(1) as count(*)] |
| | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=1 |
| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(1) as count(*)] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | |
+---------------+---------------------------------------------------------------------------------------+
"
);
Ok(())
Expand Down Expand Up @@ -3500,9 +3500,9 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
| | HashJoinExec: mode=CollectLeft, join_type=Right, on=[(a@1, a@0)], projection=[a@3, b@4, count(*)@0, __always_true@2] |
| | CoalescePartitionsExec |
| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |
| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(1) as count(*)] |
| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |
| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(1) as count(*)] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | DataSourceExec: partitions=1, partition_sizes=[1] |
| | |
Expand Down
Loading
Loading