Skip to content

Commit a553320

Browse files
committed
add infeasible flag equality
1 parent d2b0fc3 commit a553320

1 file changed

Lines changed: 163 additions & 23 deletions

File tree

datafusion/physical-plan/src/filter.rs

Lines changed: 163 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::collections::HashMap;
1819
use std::pin::Pin;
1920
use std::sync::Arc;
2021
use std::task::{Context, Poll, ready};
@@ -320,12 +321,16 @@ impl FilterExec {
320321
predicate: &Arc<dyn PhysicalExpr>,
321322
default_selectivity: u8,
322323
) -> Result<Statistics> {
323-
let eq_columns = collect_equality_columns(predicate);
324+
let (eq_columns, is_infeasible) = collect_equality_columns(predicate);
324325

325326
let num_rows = input_stats.num_rows;
326327
let total_byte_size = input_stats.total_byte_size;
327328

328-
let (selectivity, mut column_statistics) = if !check_support(predicate, schema) {
329+
let (selectivity, mut column_statistics) = if is_infeasible {
330+
// Contradictory equality predicates (e.g. `a = 1 AND a = 2`)
331+
// can never be satisfied.
332+
(0.0, input_stats.to_inexact().column_statistics)
333+
} else if !check_support(predicate, schema) {
329334
(
330335
default_selectivity as f64 / 100.0,
331336
input_stats.to_inexact().column_statistics,
@@ -350,11 +355,17 @@ impl FilterExec {
350355
let num_rows = num_rows.with_estimated_selectivity(selectivity);
351356
let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity);
352357

353-
for idx in &eq_columns {
354-
if *idx < column_statistics.len()
355-
&& column_statistics[*idx].distinct_count != Precision::Exact(0)
356-
{
357-
column_statistics[*idx].distinct_count = Precision::Exact(1);
358+
if is_infeasible {
359+
for col_stat in &mut column_statistics {
360+
col_stat.distinct_count = Precision::Exact(0);
361+
}
362+
} else {
363+
for idx in eq_columns.keys() {
364+
if *idx < column_statistics.len()
365+
&& column_statistics[*idx].distinct_count != Precision::Exact(0)
366+
{
367+
column_statistics[*idx].distinct_count = Precision::Exact(1);
368+
}
358369
}
359370
}
360371

@@ -770,12 +781,23 @@ impl EmbeddedProjection for FilterExec {
770781
}
771782
}
772783

773-
/// Returns column indices constrained to a single value by `col = literal`
774-
/// equality predicates in a conjunction. Only recurses through AND (via
775-
/// `split_conjunction`); OR is intentionally not traversed since
776-
/// `a = 1 OR a = 2` does not pin NDV to 1.
777-
fn collect_equality_columns(predicate: &Arc<dyn PhysicalExpr>) -> Vec<usize> {
778-
let mut eq_columns = Vec::new();
784+
/// Collects column equality information from `col = literal` predicates in a
785+
/// conjunction.
786+
///
787+
/// Returns `(eq_columns, is_infeasible)`:
788+
/// - `eq_columns`: map from column index to the first non-null literal that column is equated to.
789+
/// - `is_infeasible`: `true` when the same column is equated to two different
790+
/// non-null literals (e.g. `name = 'alice' AND name = 'bob'`), which is
791+
/// always unsatisfiable.
792+
///
793+
/// Only recurses through AND (via `split_conjunction`); OR is intentionally
794+
/// not traversed since `a = 1 OR a = 2` does not pin NDV to 1.
795+
fn collect_equality_columns(
796+
predicate: &Arc<dyn PhysicalExpr>,
797+
) -> (HashMap<usize, ScalarValue>, bool) {
798+
let mut eq_columns: HashMap<usize, ScalarValue> = HashMap::new();
799+
let mut infeasible = false;
800+
779801
for expr in split_conjunction(predicate) {
780802
let Some(binary) = expr.as_any().downcast_ref::<BinaryExpr>() else {
781803
continue;
@@ -785,17 +807,32 @@ fn collect_equality_columns(predicate: &Arc<dyn PhysicalExpr>) -> Vec<usize> {
785807
}
786808
let left = binary.left();
787809
let right = binary.right();
788-
if let Some(col) = left.as_any().downcast_ref::<Column>()
789-
&& right.as_any().is::<Literal>()
810+
let pair = if let Some(col) = left.as_any().downcast_ref::<Column>()
811+
&& let Some(lit) = right.as_any().downcast_ref::<Literal>()
812+
&& !lit.value().is_null()
790813
{
791-
eq_columns.push(col.index());
814+
Some((col.index(), lit.value().clone()))
792815
} else if let Some(col) = right.as_any().downcast_ref::<Column>()
793-
&& left.as_any().is::<Literal>()
816+
&& let Some(lit) = left.as_any().downcast_ref::<Literal>()
817+
&& !lit.value().is_null()
794818
{
795-
eq_columns.push(col.index());
819+
Some((col.index(), lit.value().clone()))
820+
} else {
821+
None
822+
};
823+
824+
if let Some((idx, value)) = pair {
825+
if let Some(prev) = eq_columns.get(&idx) {
826+
if *prev != value {
827+
infeasible = true;
828+
}
829+
} else {
830+
eq_columns.insert(idx, value);
831+
}
796832
}
797833
}
798-
eq_columns
834+
835+
(eq_columns, infeasible)
799836
}
800837

801838
/// Converts an interval bound to a [`Precision`] value. NULL bounds (which
@@ -2492,6 +2529,32 @@ mod tests {
24922529
)),
24932530
vec![Precision::Exact(1)],
24942531
),
2532+
(
2533+
"contradictory utf8 equality (infeasible)",
2534+
vec![Field::new("name", DataType::Utf8, false)],
2535+
vec![ColumnStatistics {
2536+
distinct_count: Precision::Inexact(100),
2537+
..Default::default()
2538+
}],
2539+
Arc::new(BinaryExpr::new(
2540+
Arc::new(BinaryExpr::new(
2541+
Arc::new(Column::new("name", 0)),
2542+
Operator::Eq,
2543+
Arc::new(Literal::new(ScalarValue::Utf8(Some(
2544+
"alice".to_string(),
2545+
)))),
2546+
)),
2547+
Operator::And,
2548+
Arc::new(BinaryExpr::new(
2549+
Arc::new(Column::new("name", 0)),
2550+
Operator::Eq,
2551+
Arc::new(Literal::new(ScalarValue::Utf8(Some(
2552+
"bob".to_string(),
2553+
)))),
2554+
)),
2555+
)),
2556+
vec![Precision::Exact(0)],
2557+
),
24952558
];
24962559

24972560
for (desc, fields, col_stats, predicate, expected_ndvs) in cases {
@@ -2520,7 +2583,10 @@ mod tests {
25202583

25212584
#[test]
25222585
fn test_collect_equality_columns() {
2523-
let cases: Vec<(&str, Arc<dyn PhysicalExpr>, Vec<usize>)> = vec![
2586+
use std::collections::HashSet;
2587+
// (description, predicate, expected_column_indices, expected_infeasible)
2588+
#[expect(clippy::type_complexity)]
2589+
let cases: Vec<(&str, Arc<dyn PhysicalExpr>, Vec<usize>, bool)> = vec![
25242590
(
25252591
"simple col = literal",
25262592
Arc::new(BinaryExpr::new(
@@ -2529,6 +2595,7 @@ mod tests {
25292595
Arc::new(Literal::new(ScalarValue::Int32(Some(42)))),
25302596
)),
25312597
vec![0],
2598+
false,
25322599
),
25332600
(
25342601
"reversed literal = col",
@@ -2538,6 +2605,7 @@ mod tests {
25382605
Arc::new(Column::new("a", 0)),
25392606
)),
25402607
vec![0],
2608+
false,
25412609
),
25422610
(
25432611
"AND with two equalities",
@@ -2557,6 +2625,7 @@ mod tests {
25572625
)),
25582626
)),
25592627
vec![0, 1],
2628+
false,
25602629
),
25612630
(
25622631
"OR produces empty set",
@@ -2574,6 +2643,7 @@ mod tests {
25742643
)),
25752644
)),
25762645
vec![],
2646+
false,
25772647
),
25782648
(
25792649
"greater-than produces empty set",
@@ -2583,6 +2653,7 @@ mod tests {
25832653
Arc::new(Literal::new(ScalarValue::Int32(Some(42)))),
25842654
)),
25852655
vec![],
2656+
false,
25862657
),
25872658
(
25882659
"col = col produces empty set",
@@ -2592,6 +2663,7 @@ mod tests {
25922663
Arc::new(Column::new("b", 1)),
25932664
)),
25942665
vec![],
2666+
false,
25952667
),
25962668
(
25972669
"nested AND with three equalities",
@@ -2617,6 +2689,7 @@ mod tests {
26172689
)),
26182690
)),
26192691
vec![0, 1, 2],
2692+
false,
26202693
),
26212694
(
26222695
"AND with mixed equality and non-equality",
@@ -2634,12 +2707,79 @@ mod tests {
26342707
)),
26352708
)),
26362709
vec![0],
2710+
false,
2711+
),
2712+
(
2713+
"col = NULL is excluded",
2714+
Arc::new(BinaryExpr::new(
2715+
Arc::new(Column::new("a", 0)),
2716+
Operator::Eq,
2717+
Arc::new(Literal::new(ScalarValue::Int32(None))),
2718+
)),
2719+
vec![],
2720+
false,
2721+
),
2722+
(
2723+
"NULL = col is excluded",
2724+
Arc::new(BinaryExpr::new(
2725+
Arc::new(Literal::new(ScalarValue::Utf8(None))),
2726+
Operator::Eq,
2727+
Arc::new(Column::new("a", 0)),
2728+
)),
2729+
vec![],
2730+
false,
2731+
),
2732+
(
2733+
"contradictory: same col, different literals",
2734+
Arc::new(BinaryExpr::new(
2735+
Arc::new(BinaryExpr::new(
2736+
Arc::new(Column::new("a", 0)),
2737+
Operator::Eq,
2738+
Arc::new(Literal::new(ScalarValue::Utf8(Some(
2739+
"alice".to_string(),
2740+
)))),
2741+
)),
2742+
Operator::And,
2743+
Arc::new(BinaryExpr::new(
2744+
Arc::new(Column::new("a", 0)),
2745+
Operator::Eq,
2746+
Arc::new(Literal::new(ScalarValue::Utf8(Some(
2747+
"bob".to_string(),
2748+
)))),
2749+
)),
2750+
)),
2751+
vec![0],
2752+
true,
2753+
),
2754+
(
2755+
"same col, same literal is not contradictory",
2756+
Arc::new(BinaryExpr::new(
2757+
Arc::new(BinaryExpr::new(
2758+
Arc::new(Column::new("a", 0)),
2759+
Operator::Eq,
2760+
Arc::new(Literal::new(ScalarValue::Int32(Some(42)))),
2761+
)),
2762+
Operator::And,
2763+
Arc::new(BinaryExpr::new(
2764+
Arc::new(Column::new("a", 0)),
2765+
Operator::Eq,
2766+
Arc::new(Literal::new(ScalarValue::Int32(Some(42)))),
2767+
)),
2768+
)),
2769+
vec![0],
2770+
false,
26372771
),
26382772
];
26392773

2640-
for (desc, expr, expected) in cases {
2641-
let result = collect_equality_columns(&expr);
2642-
assert_eq!(result, expected, "case '{desc}': mismatch");
2774+
for (desc, expr, expected_cols, expected_infeasible) in cases {
2775+
let (result, infeasible) = collect_equality_columns(&expr);
2776+
let result_keys: HashSet<usize> = result.keys().copied().collect();
2777+
let expected: HashSet<usize> = expected_cols.into_iter().collect();
2778+
assert_eq!(result_keys, expected, "case '{desc}': columns mismatch");
2779+
assert_eq!(
2780+
infeasible, expected_infeasible,
2781+
"case '{desc}': infeasible mismatch"
2782+
);
26432783
}
26442784
}
26452785
}

0 commit comments

Comments
 (0)