210 changes: 203 additions & 7 deletions datafusion/physical-plan/src/joins/hash_join/exec.rs
@@ -66,7 +66,7 @@ use crate::{
};

use arrow::array::{ArrayRef, BooleanBufferBuilder};
use arrow::compute::concat_batches;
use arrow::compute::{concat, concat_batches};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util;
@@ -87,6 +87,7 @@ use datafusion_physical_expr::equivalence::{
};
use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit};
use datafusion_physical_expr::projection::{ProjectionRef, combine_projections};
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

use datafusion_common::hash_utils::RandomState;
@@ -113,6 +114,7 @@ fn try_create_array_map(
perfect_hash_join_small_build_threshold: usize,
perfect_hash_join_min_key_density: f64,
null_equality: NullEquality,
build_side_projection: &Option<(Vec<usize>, Vec<usize>)>,
) -> Result<Option<(ArrayMap, RecordBatch, Vec<ArrayRef>)>> {
if on_left.len() != 1 {
return Ok(None);
@@ -178,8 +180,12 @@
let mem_size = ArrayMap::estimate_memory_size(min_val, max_val, num_row);
reservation.try_grow(mem_size)?;

let batch = concat_batches(schema, batches)?;
let left_values = evaluate_expressions_to_arrays(on_left, &batch)?;
let (batch, left_values) = concat_and_evaluate_build_side(
on_left,
schema,
batches.iter(),
build_side_projection,
)?;

let array_map = ArrayMap::try_new(&left_values[0], min_val, max_val)?;

@@ -217,6 +223,9 @@ pub(super) struct JoinLeftData {
pub(super) probe_side_non_empty: AtomicBool,
/// Shared atomic flag indicating if any probe partition saw NULL in join keys (for null-aware anti joins)
pub(super) probe_side_has_null: AtomicBool,
/// Mapping from original build-side column index to projected column index.
/// `None` if no projection was applied (all columns kept).
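/// For example (hypothetical values): keeping columns `{0, 2}` of a
/// three-column build side gives `Some(vec![0, 0, 1])`; the slot for the
/// dropped column 1 is never read after remapping.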
build_column_remap: Option<Vec<usize>>,
}

impl JoinLeftData {
@@ -250,6 +259,11 @@ impl JoinLeftData {
pub(super) fn report_probe_completed(&self) -> bool {
self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
}

/// Returns the build-side column remap table, if projection was applied.
pub(super) fn build_column_remap(&self) -> Option<&[usize]> {
self.build_column_remap.as_deref()
}
}

/// Helps to build [`HashJoinExec`].
@@ -1316,6 +1330,17 @@ impl ExecutionPlan for HashJoinExec {
.with_category(MetricCategory::Rows)
.counter(ARRAY_MAP_CREATED_COUNT_METRIC_NAME, partition);

// Compute which build-side columns are actually needed.
// column_indices_after_projection is computed later, so use the full
// column_indices here (projection only narrows further, never adds).
let build_side_projection = compute_build_side_projection(
&on_left,
&self.column_indices,
self.filter.as_ref(),
JoinSide::Left,
self.left.schema().fields().len(),
);

let left_fut = match self.mode {
PartitionMode::CollectLeft => self.left_fut.try_once(|| {
let left_stream = self.left.execute(0, Arc::clone(&context))?;
@@ -1335,6 +1360,7 @@ impl ExecutionPlan for HashJoinExec {
Arc::clone(context.session_config().options()),
self.null_equality,
array_map_created_count,
build_side_projection.clone(),
))
})?,
PartitionMode::Partitioned => {
@@ -1356,6 +1382,7 @@ impl ExecutionPlan for HashJoinExec {
Arc::clone(context.session_config().options()),
self.null_equality,
array_map_created_count,
build_side_projection.clone(),
))
}
PartitionMode::Auto => {
@@ -1855,6 +1882,170 @@ fn should_collect_min_max_for_perfect_hash(
Ok(ArrayMap::is_supported_type(&data_type))
}

/// Evaluates join key expressions on each batch individually, then concatenates
/// the per-batch result arrays into flat arrays. This allows projecting the
/// original batches to fewer columns before the expensive `concat_batches`,
/// since join key expressions may reference columns not needed in the output.
///
/// # Example
///
/// Given `on_left = [col("customer_id")]` and two batches:
///
/// ```text
/// Batch 0: { order_id: [1,2], customer_id: [10,20], notes: ["a","b"] }
/// Batch 1: { order_id: [3], customer_id: [30], notes: ["c"] }
/// ```
///
/// Step 1 — evaluate `col("customer_id")` on each batch:
/// ```text
/// Batch 0 → [10, 20]
/// Batch 1 → [30]
/// ```
///
/// Step 2 — concat per-key arrays:
/// ```text
/// key 0 (customer_id) → [10, 20, 30]
/// ```
///
/// Returns `vec![ [10, 20, 30] ]` — the flat join key arrays, without ever
/// needing a full `concat_batches` of all columns.
fn evaluate_and_concat_per_batch<'a>(
on_left: &[PhysicalExprRef],
batches: impl Iterator<Item = &'a RecordBatch>,
) -> Result<Vec<ArrayRef>> {
let mut per_key_arrays: Vec<Vec<ArrayRef>> = vec![Vec::new(); on_left.len()];
for batch in batches {
if batch.num_rows() == 0 {
continue;
}
let arrays = evaluate_expressions_to_arrays(on_left, batch)?;
for (i, arr) in arrays.into_iter().enumerate() {
per_key_arrays[i].push(arr);
}
}
per_key_arrays
.into_iter()
.map(|arrs| {
if arrs.is_empty() {
// All input batches were empty (or none were supplied), so there
// is no array to take a data type from; fall back to an empty
// Null-typed array.
Ok(arrow::array::new_empty_array(&DataType::Null))
} else {
let refs: Vec<&dyn arrow::array::Array> =
arrs.iter().map(|a| a.as_ref()).collect();
Ok(concat(&refs)?)
}
})
.collect()
}

/// Concatenates build-side batches and evaluates join key expressions,
/// optionally projecting to only the needed columns first.
///
/// When `build_side_projection` is `Some`, evaluates join keys per-batch
/// (before projection removes columns they may reference), then projects
/// and concatenates only the needed columns. When `None`, uses the original
/// path: concat all columns, evaluate once on the result.
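///
/// A sketch of the projected path, reusing the hypothetical `customer_id`
/// example above (column index 1 of the build side):
///
/// ```text
/// build_side_projection = Some(([1], remap))  // keep only customer_id
/// 1. evaluate keys per batch, then concat  → [10, 20, 30]
/// 2. project each batch to column 1, then concat_batches
///                                          → single one-column batch
/// ```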
fn concat_and_evaluate_build_side<'a>(
on_left: &[PhysicalExprRef],
schema: &SchemaRef,
batches: impl Iterator<Item = &'a RecordBatch> + Clone,
build_side_projection: &Option<(Vec<usize>, Vec<usize>)>,
) -> Result<(RecordBatch, Vec<ArrayRef>)> {
if let Some((proj_indices, _)) = build_side_projection {
let left_values = evaluate_and_concat_per_batch(on_left, batches.clone())?;
let projected_schema = Arc::new(schema.project(proj_indices)?);
let projected: Vec<RecordBatch> = batches
.map(|b| b.project(proj_indices))
.collect::<Result<_, _>>()?;
let batch = concat_batches(&projected_schema, &projected)?;
Ok((batch, left_values))
} else {
let batch = concat_batches(schema, batches)?;
let left_values = evaluate_expressions_to_arrays(on_left, &batch)?;
Ok((batch, left_values))
}
}

/// Determines which build-side columns are actually needed for the hash join
/// output, filter evaluation, and join key computation.
///
/// Returns `None` if all columns are needed (no projection benefit),
/// or `Some((projected_indices, remap))` where:
/// - `projected_indices`: sorted column indices to keep from the original schema
/// - `remap`: maps original column index → new projected index
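///
/// # Example
///
/// An illustrative sketch (indices are hypothetical): with
/// `num_build_columns = 4` and only columns `{0, 2}` needed,
///
/// ```text
/// projected_indices = [0, 2]
/// remap             = [0, 0, 1, 0]  // remap[0] = 0, remap[2] = 1;
///                                   // slots for dropped columns stay 0
///                                   // and are never read
/// ```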
fn compute_build_side_projection(
on_left: &[PhysicalExprRef],
output_column_indices: &[ColumnIndex],
filter: Option<&JoinFilter>,
build_side: JoinSide,
num_build_columns: usize,
) -> Option<(Vec<usize>, Vec<usize>)> {
let mut needed: HashSet<usize> = HashSet::new();

// 1. Columns referenced by join key expressions
for expr in on_left {
for col in collect_columns(expr) {
needed.insert(col.index());
}
}

// 2. Columns referenced by output column_indices (build side only)
for ci in output_column_indices {
if ci.side == build_side {
needed.insert(ci.index);
}
}

// 3. Columns referenced by join filter (build side only)
if let Some(f) = filter {
for ci in f.column_indices() {
if ci.side == build_side {
needed.insert(ci.index);
}
}
}

// Short-circuit: if all columns needed, skip projection
if needed.len() >= num_build_columns {
return None;
}

let mut projected_indices: Vec<usize> = needed.into_iter().collect();
projected_indices.sort_unstable();

// Build remap: original_index → projected_index
let mut remap = vec![0usize; num_build_columns];
for (new_idx, &orig_idx) in projected_indices.iter().enumerate() {
remap[orig_idx] = new_idx;
}

Some((projected_indices, remap))
}

/// Remaps build-side column indices in a `ColumnIndex` slice using the given remap table.
/// Indices on the opposite (probe) side are left unchanged.
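///
/// For instance (hypothetical indices), with `remap = [0, 0, 1]` and
/// `build_side = JoinSide::Left`:
///
/// ```text
/// [{ index: 2, Left }, { index: 5, Right }]
///   → [{ index: 1, Left }, { index: 5, Right }]
/// ```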
pub(super) fn remap_column_indices(
column_indices: &[ColumnIndex],
remap: &[usize],
build_side: JoinSide,
) -> Vec<ColumnIndex> {
column_indices
.iter()
.map(|ci| {
if ci.side == build_side {
ColumnIndex {
index: remap[ci.index],
side: ci.side,
}
} else {
ci.clone()
}
})
.collect()
}

/// Collects all batches from the left (build) side stream and creates a hash map for joining.
///
/// This function is responsible for:
Expand Down Expand Up @@ -1896,6 +2087,7 @@ async fn collect_left_input(
config: Arc<ConfigOptions>,
null_equality: NullEquality,
array_map_created_count: Count,
build_side_projection: Option<(Vec<usize>, Vec<usize>)>,
) -> Result<JoinLeftData> {
let schema = left_stream.schema();

@@ -1966,6 +2158,7 @@
config.execution.perfect_hash_join_small_build_threshold,
config.execution.perfect_hash_join_min_key_density,
null_equality,
&build_side_projection,
)? {
array_map_created_count.add(1);
metrics.build_mem_used.add(array_map.size());
@@ -2016,10 +2209,12 @@
offset += batch.num_rows();
}

// Merge all batches into a single batch, so we can directly index into the arrays
let batch = concat_batches(&schema, batches_iter.clone())?;

let left_values = evaluate_expressions_to_arrays(&on_left, &batch)?;
let (batch, left_values) = concat_and_evaluate_build_side(
&on_left,
&schema,
batches_iter,
&build_side_projection,
)?;

(Map::HashMap(hashmap), batch, left_values)
};
@@ -2080,6 +2275,7 @@
membership,
probe_side_non_empty: AtomicBool::new(false),
probe_side_has_null: AtomicBool::new(false),
build_column_remap: build_side_projection.map(|(_, remap)| remap),
};

Ok(data)
25 changes: 25 additions & 0 deletions datafusion/physical-plan/src/joins/hash_join/stream.rs
@@ -562,6 +562,31 @@ impl HashJoinStream {
}

self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });

// If the build side was projected to fewer columns, remap column indices
// so they reference the projected batch positions instead of the original schema.
if let BuildSide::Ready(ref ready) = self.build_side
&& let Some(remap) = ready.left_data.build_column_remap()
{
self.column_indices = super::exec::remap_column_indices(
&self.column_indices,
remap,
JoinSide::Left,
);
if let Some(ref filter) = self.filter {
let remapped_filter_indices = super::exec::remap_column_indices(
filter.column_indices(),
remap,
JoinSide::Left,
);
self.filter = Some(JoinFilter::new(
Arc::clone(filter.expression()),
remapped_filter_indices,
Arc::clone(filter.schema()),
));
}
}

Poll::Ready(Ok(StatefulStreamResult::Continue))
}
