diff --git a/datafusion-examples/examples/custom_data_source/literal_guarantee_filter_pushdown.rs b/datafusion-examples/examples/custom_data_source/literal_guarantee_filter_pushdown.rs
new file mode 100644
index 000000000000..c76b42b4b1ee
--- /dev/null
+++ b/datafusion-examples/examples/custom_data_source/literal_guarantee_filter_pushdown.rs
@@ -0,0 +1,394 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! See `main.rs` for how to run it.
+//!
+//! Example demonstrating filter pushdown using [`LiteralGuarantee`] with a
+//! custom index.
+//!
+//! A [`TableProvider`] with a `BTreeMap` index on `bank_account` uses
+//! [`LiteralGuarantee::analyze`] to extract equality / IN-list constraints
+//! from pushed-down filter predicates and resolves them via the index,
+//! returning only the matching rows.
+
+use std::any::Any;
+use std::collections::{BTreeMap, HashMap};
+use std::fmt::{self, Debug, Formatter};
+use std::sync::{Arc, Mutex};
+
+use async_trait::async_trait;
+use datafusion::arrow::array::{UInt8Builder, UInt64Builder};
+use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::catalog::Session;
+use datafusion::common::tree_node::TreeNodeRecursion;
+use datafusion::datasource::{TableProvider, TableType};
+use datafusion::error::Result;
+use datafusion::execution::context::TaskContext;
+use datafusion::physical_expr::EquivalenceProperties;
+use datafusion::physical_expr::planner::logical2physical;
+use datafusion::physical_expr::utils::Guarantee;
+use datafusion::physical_expr::utils::LiteralGuarantee;
+use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+use datafusion::physical_plan::memory::MemoryStream;
+use datafusion::physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
+    SendableRecordBatchStream,
+};
+use datafusion::prelude::*;
+use datafusion::scalar::ScalarValue;
+use datafusion_expr::TableProviderFilterPushDown;
+use datafusion_expr::utils::conjunction;
+
+pub async fn literal_guarantee_filter_pushdown() -> Result<()> {
+    let db = CustomDataSource::default();
+    db.populate_users();
+
+    let ctx = SessionContext::new();
+    ctx.register_table("accounts", Arc::new(db))?;
+
+    // No filter — returns all 3 rows.
+    let all = ctx.sql("SELECT * FROM accounts").await?.collect().await?;
+    assert_eq!(all.iter().map(|b| b.num_rows()).sum::<usize>(), 3);
+    println!("All rows:");
+    for batch in &all {
+        println!("{batch:?}");
+    }
+
+    // Equality filter — the index resolves exactly one row.
+    let one = ctx
+        .sql("SELECT * FROM accounts WHERE bank_account = 9000")
+        .await?
+        .collect()
+        .await?;
+    assert_eq!(one.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
+    println!("\nbank_account = 9000:");
+    for batch in &one {
+        println!("{batch:?}");
+    }
+
+    // IN-list filter — the index resolves two rows.
+    let two = ctx
+        .sql("SELECT * FROM accounts WHERE bank_account IN (100, 1000)")
+        .await?
+        .collect()
+        .await?;
+    assert_eq!(two.iter().map(|b| b.num_rows()).sum::<usize>(), 2);
+    println!("\nbank_account IN (100, 1000):");
+    for batch in &two {
+        println!("{batch:?}");
+    }
+
+    // NOT-IN filter — the index excludes one row, returning the other two.
+    let not_in = ctx
+        .sql("SELECT * FROM accounts WHERE bank_account != 9000")
+        .await?
+        .collect()
+        .await?;
+    assert_eq!(not_in.iter().map(|b| b.num_rows()).sum::<usize>(), 2);
+    println!("\nbank_account != 9000:");
+    for batch in &not_in {
+        println!("{batch:?}");
+    }
+
+    // Inexact pushdown — the filter references bank_account (indexed) AND id
+    // (not indexed). The index narrows the scan to bank_account IN (100, 1000),
+    // but DataFusion must re-check "id > 2" since we can't handle that part.
+    let inexact = ctx
+        .sql("SELECT * FROM accounts WHERE bank_account IN (100, 1000) AND id > 2")
+        .await?
+        .collect()
+        .await?;
+    assert_eq!(inexact.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
+    println!("\nbank_account IN (100, 1000) AND id > 2 (inexact pushdown):");
+    for batch in &inexact {
+        println!("{batch:?}");
+    }
+
+    Ok(())
+}
+
+/// A User, with an id and a bank account
+#[derive(Clone, Debug)]
+struct User {
+    id: u8,
+    bank_account: u64,
+}
+
+/// A custom datasource with a `BTreeMap` index on `bank_account`
+#[derive(Clone)]
+struct CustomDataSource {
+    inner: Arc<Mutex<CustomDataSourceInner>>,
+}
+
+struct CustomDataSourceInner {
+    data: HashMap<u8, User>,
+    bank_account_index: BTreeMap<u64, u8>,
+}
+
+impl Debug for CustomDataSource {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.write_str("custom_db")
+    }
+}
+
+impl Default for CustomDataSource {
+    fn default() -> Self {
+        Self {
+            inner: Arc::new(Mutex::new(CustomDataSourceInner {
+                data: Default::default(),
+                bank_account_index: Default::default(),
+            })),
+        }
+    }
+}
+
+impl CustomDataSource {
+    fn populate_users(&self) {
+        for user in [
+            User {
+                id: 1,
+                bank_account: 9_000,
+            },
+            User {
+                id: 2,
+                bank_account: 100,
+            },
+            User {
+                id: 3,
+                bank_account: 1_000,
+            },
+        ] {
+            let mut inner = self.inner.lock().unwrap();
+            inner.bank_account_index.insert(user.bank_account, user.id);
+            inner.data.insert(user.id, user);
+        }
+    }
+}
+
+fn schema() -> SchemaRef {
+    Arc::new(Schema::new(vec![
+        Field::new("id", DataType::UInt8, false),
+        Field::new("bank_account", DataType::UInt64, false),
+    ]))
+}
+
+#[async_trait]
+impl TableProvider for CustomDataSource {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        schema()
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> Result<Vec<TableProviderFilterPushDown>> {
+        Ok(filters
+            .iter()
+            .map(|f| {
+                let columns = f.column_refs();
+                let has_bank_account = columns.iter().any(|c| c.name == "bank_account");
+                let all_bank_account =
+                    has_bank_account && columns.iter().all(|c| c.name == "bank_account");
+                if all_bank_account {
+                    // Only bank_account — fully handled via the index.
+                    TableProviderFilterPushDown::Exact
+                } else if has_bank_account {
+                    // Mixed columns — index narrows the scan, but DataFusion
+                    // must re-check the parts we can't handle.
+                    TableProviderFilterPushDown::Inexact
+                } else {
+                    // No indexed columns — we can't help with this filter.
+                    TableProviderFilterPushDown::Unsupported
+                }
+            })
+            .collect())
+    }
+
+    async fn scan(
+        &self,
+        _state: &dyn Session,
+        projection: Option<&Vec<usize>>,
+        filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // Convert the logical filter expressions into a single physical
+        // expression so we can feed it to LiteralGuarantee::analyze.
+        let filter = conjunction(filters.to_vec()).and_then(|expr| {
+            let physical_expr = logical2physical(&expr, self.schema().as_ref());
+            let guarantees = LiteralGuarantee::analyze(&physical_expr);
+
+            // Look for a guarantee on the `bank_account` column.
+            let guarantee = guarantees
+                .iter()
+                .find(|g| g.column.name() == "bank_account")?;
+
+            // Resolve the literal values to user ids via the index.
+            let inner = self.inner.lock().unwrap();
+            let ids: Vec<u8> = guarantee
+                .literals
+                .iter()
+                .filter_map(|v| {
+                    if let ScalarValue::UInt64(Some(val)) = v {
+                        inner.bank_account_index.get(val).copied()
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+            Some(IndexFilter {
+                ids,
+                guarantee: guarantee.guarantee,
+            })
+        });
+
+        Ok(Arc::new(CustomExec::new(self.clone(), projection, filter)))
+    }
+}
+
+/// Result of resolving a `LiteralGuarantee` against the index.
+#[derive(Debug, Clone)]
+struct IndexFilter {
+    ids: Vec<u8>,
+    guarantee: Guarantee,
+}
+
+#[derive(Debug, Clone)]
+struct CustomExec {
+    db: CustomDataSource,
+    projected_schema: SchemaRef,
+    /// When Some, filters users by including (In) or excluding (NotIn) IDs.
+    filter: Option<IndexFilter>,
+    cache: Arc<PlanProperties>,
+}
+
+impl CustomExec {
+    fn new(
+        db: CustomDataSource,
+        projections: Option<&Vec<usize>>,
+        filter: Option<IndexFilter>,
+    ) -> Self {
+        let projected_schema =
+            datafusion::physical_plan::project_schema(&schema(), projections).unwrap();
+        let cache = Arc::new(PlanProperties::new(
+            EquivalenceProperties::new(Arc::clone(&projected_schema)),
+            Partitioning::UnknownPartitioning(1),
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        ));
+        Self {
+            db,
+            projected_schema,
+            filter,
+            cache,
+        }
+    }
+}
+
+impl DisplayAs for CustomExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
+        write!(f, "CustomExec(filter={:?})", self.filter)
+    }
+}
+
+impl ExecutionPlan for CustomExec {
+    fn name(&self) -> &'static str {
+        "CustomExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &Arc<PlanProperties> {
+        &self.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn apply_expressions(
+        &self,
+        _f: &mut dyn FnMut(
+            &dyn datafusion::physical_plan::PhysicalExpr,
+        ) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        Ok(TreeNodeRecursion::Continue)
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let inner = self.db.inner.lock().unwrap();
+
+        // Collect the users to return based on the index filter.
+        let users: Vec<&User> = match &self.filter {
+            Some(IndexFilter {
+                ids,
+                guarantee: Guarantee::In,
+            }) => ids.iter().filter_map(|id| inner.data.get(id)).collect(),
+            Some(IndexFilter {
+                ids,
+                guarantee: Guarantee::NotIn,
+            }) => inner
+                .data
+                .values()
+                .filter(|u| !ids.contains(&u.id))
+                .collect(),
+            None => inner.data.values().collect(),
+        };
+
+        let mut id_array = UInt8Builder::with_capacity(users.len());
+        let mut account_array = UInt64Builder::with_capacity(users.len());
+        for user in &users {
+            id_array.append_value(user.id);
+            account_array.append_value(user.bank_account);
+        }
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&self.projected_schema),
+            vec![
+                Arc::new(id_array.finish()),
+                Arc::new(account_array.finish()),
+            ],
+        )?;
+
+        Ok(Box::pin(MemoryStream::try_new(
+            vec![batch],
+            self.schema(),
+            None,
+        )?))
+    }
+}
diff --git a/datafusion-examples/examples/custom_data_source/main.rs b/datafusion-examples/examples/custom_data_source/main.rs
index 0d21a6259112..eacf75c52371 100644
--- a/datafusion-examples/examples/custom_data_source/main.rs
+++ b/datafusion-examples/examples/custom_data_source/main.rs
@@ -21,7 +21,7 @@
 //!
 //! ## Usage
 //! ```bash
-//! cargo run --example custom_data_source -- [all|csv_json_opener|csv_sql_streaming|custom_datasource|custom_file_casts|custom_file_format|default_column_values|file_stream_provider]
+//! cargo run --example custom_data_source -- [all|csv_json_opener|csv_sql_streaming|custom_datasource|custom_file_casts|custom_file_format|default_column_values|file_stream_provider|literal_guarantee_filter_pushdown]
 //! ```
 //!
 //! Each subcommand runs a corresponding example:
@@ -50,6 +50,12 @@
 //!
 //! - `file_stream_provider`
 //!   (file: file_stream_provider.rs, desc: Read/write via FileStreamProvider for streams)
+//!
+//! - `literal_guarantee_filter_pushdown`
+//!   (file: literal_guarantee_filter_pushdown.rs, desc: Filter pushdown using LiteralGuarantee with a custom index)
+//!
+//! - `pruning_predicate_filter_pushdown`
+//!   (file: pruning_predicate_filter_pushdown.rs, desc: Filter pushdown using PruningPredicate with partition statistics)
 
 mod adapter_serialization;
 mod csv_json_opener;
@@ -59,6 +65,8 @@ mod custom_file_casts;
 mod custom_file_format;
 mod default_column_values;
 mod file_stream_provider;
+mod literal_guarantee_filter_pushdown;
+mod pruning_predicate_filter_pushdown;
 
 use datafusion::error::{DataFusionError, Result};
 use strum::{IntoEnumIterator, VariantNames};
@@ -74,6 +82,8 @@ enum ExampleKind {
     CustomDatasource,
     CustomFileCasts,
     CustomFileFormat,
+    LiteralGuaranteeFilterPushdown,
+    PruningPredicateFilterPushdown,
     DefaultColumnValues,
     FileStreamProvider,
 }
@@ -109,6 +119,14 @@ impl ExampleKind {
             ExampleKind::CustomFileFormat => {
                 custom_file_format::custom_file_format().await?
             }
+            ExampleKind::LiteralGuaranteeFilterPushdown => {
+                literal_guarantee_filter_pushdown::literal_guarantee_filter_pushdown()
+                    .await?
+            }
+            ExampleKind::PruningPredicateFilterPushdown => {
+                pruning_predicate_filter_pushdown::pruning_predicate_filter_pushdown()
+                    .await?
+            }
             ExampleKind::DefaultColumnValues => {
                 default_column_values::default_column_values().await?
             }
diff --git a/datafusion-examples/examples/custom_data_source/pruning_predicate_filter_pushdown.rs b/datafusion-examples/examples/custom_data_source/pruning_predicate_filter_pushdown.rs
new file mode 100644
index 000000000000..8a86be3a9a62
--- /dev/null
+++ b/datafusion-examples/examples/custom_data_source/pruning_predicate_filter_pushdown.rs
@@ -0,0 +1,394 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! See `main.rs` for how to run it.
+//!
+//! Example demonstrating filter pushdown using [`PruningPredicate`] to skip
+//! entire partitions based on min/max statistics.
+//!
+//! A [`TableProvider`] with multiple in-memory partitions uses
+//! [`PruningPredicate::prune`] to determine which partitions could contain
+//! matching rows and only scans those.
+
+use std::any::Any;
+use std::collections::HashSet;
+use std::fmt::{self, Debug, Formatter};
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use datafusion::arrow::array::{
+    ArrayRef, BooleanArray, UInt8Builder, UInt64Array, UInt64Builder,
+};
+use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::catalog::Session;
+use datafusion::common::pruning::PruningStatistics;
+use datafusion::common::tree_node::TreeNodeRecursion;
+use datafusion::datasource::{TableProvider, TableType};
+use datafusion::error::Result;
+use datafusion::execution::context::TaskContext;
+use datafusion::physical_expr::EquivalenceProperties;
+use datafusion::physical_expr::planner::logical2physical;
+use datafusion::physical_optimizer::pruning::PruningPredicate;
+use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+use datafusion::physical_plan::memory::MemoryStream;
+use datafusion::physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
+    SendableRecordBatchStream,
+};
+use datafusion::prelude::*;
+use datafusion::scalar::ScalarValue;
+use datafusion_expr::TableProviderFilterPushDown;
+use datafusion_expr::utils::conjunction;
+
+pub async fn pruning_predicate_filter_pushdown() -> Result<()> {
+    let db = CustomDataSource::new(vec![
+        Partition::new(
+            vec![
+                User {
+                    id: 1,
+                    bank_account: 100,
+                },
+                User {
+                    id: 2,
+                    bank_account: 500,
+                },
+            ],
+            100,
+            500,
+        ),
+        Partition::new(
+            vec![
+                User {
+                    id: 3,
+                    bank_account: 1_000,
+                },
+                User {
+                    id: 4,
+                    bank_account: 4_000,
+                },
+            ],
+            1_000,
+            4_000,
+        ),
+        Partition::new(
+            vec![
+                User {
+                    id: 5,
+                    bank_account: 5_000,
+                },
+                User {
+                    id: 6,
+                    bank_account: 9_000,
+                },
+            ],
+            5_000,
+            9_000,
+        ),
+    ]);
+
+    let ctx = SessionContext::new();
+    ctx.register_table("accounts", Arc::new(db))?;
+
+    // No filter: all 3 partitions scanned, 6 rows returned.
+    let all = ctx.sql("SELECT * FROM accounts").await?.collect().await?;
+    assert_eq!(all.iter().map(|b| b.num_rows()).sum::<usize>(), 6);
+    println!("All rows:");
+    for batch in &all {
+        println!("{batch:?}");
+    }
+
+    // Equality filter: pruning skips partitions 0 and 1 because 9000 is
+    // outside their ranges. Partition 2 (5000–9000) is scanned, but it also
+    // contains bank_account=5000 which doesn't match. Because our pushdown
+    // is Inexact, DataFusion re-checks the filter and removes that row.
+    let one = ctx
+        .sql("SELECT * FROM accounts WHERE bank_account = 9000")
+        .await?
+        .collect()
+        .await?;
+    assert_eq!(one.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
+    println!("\nbank_account = 9000:");
+    for batch in &one {
+        println!("{batch:?}");
+    }
+
+    // Range filter: only partition 1 (1000–4000) needs scanning.
+    let range = ctx
+        .sql("SELECT * FROM accounts WHERE bank_account >= 1000 AND bank_account <= 4000")
+        .await?
+        .collect()
+        .await?;
+    assert_eq!(range.iter().map(|b| b.num_rows()).sum::<usize>(), 2);
+    println!("\nbank_account >= 1000 AND bank_account <= 4000:");
+    for batch in &range {
+        println!("{batch:?}");
+    }
+
+    // Inexact pushdown: pruning keeps partitions 1 and 2 because
+    // both overlap with the range > 500, skipping only partition 0 (100–500).
+    // But within the kept partitions, DataFusion re-checks and removes rows
+    // that don't satisfy bank_account > 4000 (e.g. bank_account=1000 from
+    // partition 1).
+    let inexact = ctx
+        .sql("SELECT * FROM accounts WHERE bank_account > 4000")
+        .await?
+        .collect()
+        .await?;
+    assert_eq!(inexact.iter().map(|b| b.num_rows()).sum::<usize>(), 2);
+    println!("\nbank_account > 4000 (inexact: partition 1 scanned but no rows match):");
+    for batch in &inexact {
+        println!("{batch:?}");
+    }
+
+    Ok(())
+}
+
+/// A User, with an id and a bank account
+#[derive(Clone, Debug)]
+struct User {
+    id: u8,
+    bank_account: u64,
+}
+
+/// A partition of users with known min/max bank_account statistics
+#[derive(Clone, Debug)]
+struct Partition {
+    users: Vec<User>,
+    min_account: u64,
+    max_account: u64,
+}
+
+impl Partition {
+    fn new(users: Vec<User>, min_account: u64, max_account: u64) -> Self {
+        Self {
+            users,
+            min_account,
+            max_account,
+        }
+    }
+}
+
+/// A custom datasource with partitions that have known min/max statistics
+#[derive(Clone, Debug)]
+struct CustomDataSource {
+    partitions: Vec<Partition>,
+}
+
+impl CustomDataSource {
+    fn new(partitions: Vec<Partition>) -> Self {
+        Self { partitions }
+    }
+}
+
+fn schema() -> SchemaRef {
+    Arc::new(Schema::new(vec![
+        Field::new("id", DataType::UInt8, false),
+        Field::new("bank_account", DataType::UInt64, false),
+    ]))
+}
+
+/// Provides min/max statistics for each partition so [`PruningPredicate`]
+/// can decide which partitions to skip.
+struct PartitionStats<'a> {
+    partitions: &'a [Partition],
+}
+
+impl PruningStatistics for PartitionStats<'_> {
+    fn min_values(&self, column: &Column) -> Option<ArrayRef> {
+        if column.name == "bank_account" {
+            let mins: Vec<u64> = self.partitions.iter().map(|p| p.min_account).collect();
+            Some(Arc::new(UInt64Array::from(mins)))
+        } else {
+            None
+        }
+    }
+
+    fn max_values(&self, column: &Column) -> Option<ArrayRef> {
+        if column.name == "bank_account" {
+            let maxes: Vec<u64> = self.partitions.iter().map(|p| p.max_account).collect();
+            Some(Arc::new(UInt64Array::from(maxes)))
+        } else {
+            None
+        }
+    }
+
+    fn num_containers(&self) -> usize {
+        self.partitions.len()
+    }
+
+    fn null_counts(&self, _column: &Column) -> Option<ArrayRef> {
+        None
+    }
+
+    fn row_counts(&self, _column: &Column) -> Option<ArrayRef> {
+        None
+    }
+
+    fn contained(
+        &self,
+        _column: &Column,
+        _values: &HashSet<ScalarValue>,
+    ) -> Option<BooleanArray> {
+        None
+    }
+}
+
+#[async_trait]
+impl TableProvider for CustomDataSource {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        schema()
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> Result<Vec<TableProviderFilterPushDown>> {
+        // Pruning narrows which partitions to scan but doesn't guarantee
+        // every row in a scanned partition matches, so Inexact is correct.
+        Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()])
+    }
+
+    async fn scan(
+        &self,
+        _state: &dyn Session,
+        projection: Option<&Vec<usize>>,
+        filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // Use PruningPredicate to determine which partitions to scan.
+        let partitions_to_scan = if let Some(expr) = conjunction(filters.to_vec()) {
+            let physical_expr = logical2physical(&expr, self.schema().as_ref());
+            let predicate = PruningPredicate::try_new(physical_expr, self.schema())?;
+            let keep = predicate.prune(&PartitionStats {
+                partitions: &self.partitions,
+            })?;
+            self.partitions
+                .iter()
+                .zip(keep)
+                .filter(|(_, keep)| *keep)
+                .map(|(p, _)| p.clone())
+                .collect()
+        } else {
+            self.partitions.clone()
+        };
+
+        Ok(Arc::new(CustomExec::new(partitions_to_scan, projection)))
+    }
+}
+
+#[derive(Debug, Clone)]
+struct CustomExec {
+    partitions: Vec<Partition>,
+    projected_schema: SchemaRef,
+    cache: Arc<PlanProperties>,
+}
+
+impl CustomExec {
+    fn new(partitions: Vec<Partition>, projections: Option<&Vec<usize>>) -> Self {
+        let projected_schema =
+            datafusion::physical_plan::project_schema(&schema(), projections).unwrap();
+        let cache = Arc::new(PlanProperties::new(
+            EquivalenceProperties::new(Arc::clone(&projected_schema)),
+            Partitioning::UnknownPartitioning(1),
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        ));
+        Self {
+            partitions,
+            projected_schema,
+            cache,
+        }
+    }
+}
+
+impl DisplayAs for CustomExec {
+    fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result {
+        write!(f, "CustomExec(partitions={})", self.partitions.len())
+    }
+}
+
+impl ExecutionPlan for CustomExec {
+    fn name(&self) -> &'static str {
+        "CustomExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &Arc<PlanProperties> {
+        &self.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn apply_expressions(
+        &self,
+        _f: &mut dyn FnMut(
+            &dyn datafusion::physical_plan::PhysicalExpr,
+        ) -> Result<TreeNodeRecursion>,
+    ) -> Result<TreeNodeRecursion> {
+        Ok(TreeNodeRecursion::Continue)
+    }
+
+    fn execute(
+        &self,
+        _partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let users: Vec<&User> = self.partitions.iter().flat_map(|p| &p.users).collect();
+
+        let mut id_array = UInt8Builder::with_capacity(users.len());
+        let mut account_array = UInt64Builder::with_capacity(users.len());
+        for user in &users {
+            id_array.append_value(user.id);
+            account_array.append_value(user.bank_account);
+        }
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&self.projected_schema),
+            vec![
+                Arc::new(id_array.finish()),
+                Arc::new(account_array.finish()),
+            ],
+        )?;
+
+        Ok(Box::pin(MemoryStream::try_new(
+            vec![batch],
+            self.schema(),
+            None,
+        )?))
+    }
+}