Skip to content

Commit bbfc11f

Browse files
committed
Add support for getting inner FFI types on foreign types that use boxed traits
1 parent de4391f commit bbfc11f

10 files changed

Lines changed: 36 additions & 24 deletions

File tree

datafusion/expr-common/src/accumulator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ use std::fmt::Debug;
4848
/// [`evaluate`]: Self::evaluate
4949
/// [`merge_batch`]: Self::merge_batch
5050
/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
51-
pub trait Accumulator: Send + Sync + Debug {
51+
pub trait Accumulator: Send + Sync + Debug + std::any::Any {
5252
/// Updates the accumulator's state from its input.
5353
///
5454
/// `values` contains the arguments to this aggregate function.

datafusion/expr-common/src/groups_accumulator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ impl EmitTo {
108108
///
109109
/// [`Accumulator`]: crate::accumulator::Accumulator
110110
/// [Aggregating Millions of Groups Fast blog]: https://arrow.apache.org/blog/2023/08/05/datafusion_fast_grouping/
111-
pub trait GroupsAccumulator: Send {
111+
pub trait GroupsAccumulator: Send + std::any::Any {
112112
/// Updates the accumulator's state from its arguments, encoded as
113113
/// a vector of [`ArrayRef`]s.
114114
///

datafusion/expr/src/partition_evaluator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ use crate::window_state::WindowAggState;
9090
/// For more background, please also see the [User defined Window Functions in DataFusion blog]
9191
///
9292
/// [User defined Window Functions in DataFusion blog]: https://datafusion.apache.org/blog/2025/04/19/user-defined-window-functions
93-
pub trait PartitionEvaluator: Debug + Send {
93+
pub trait PartitionEvaluator: Debug + Send + std::any::Any {
9494
/// When the window frame has a fixed beginning (e.g. UNBOUNDED
9595
/// PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and
9696
/// NTH_VALUE do not need the (unbounded) input once they have

datafusion/ffi/src/udaf/accumulator.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::any::Any;
1819
use std::ffi::c_void;
1920
use std::ops::Deref;
2021
use std::ptr::null_mut;
@@ -204,9 +205,12 @@ unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) {
204205

205206
impl From<Box<dyn Accumulator>> for FFI_Accumulator {
206207
fn from(accumulator: Box<dyn Accumulator>) -> Self {
207-
// if let Some(accumulator) = accumulator.into_any().downcast::<ForeignAccumulator>() {
208-
// return accumulator.accumulator;
209-
// }
208+
if (accumulator.as_ref() as &dyn Any).is::<ForeignAccumulator>() {
209+
let accumulator = (accumulator as Box<dyn Any>)
210+
.downcast::<ForeignAccumulator>()
211+
.expect("already checked type");
212+
return accumulator.accumulator;
213+
}
210214

211215
let supports_retract_batch = accumulator.supports_retract_batch();
212216
let private_data = AccumulatorPrivateData { accumulator };

datafusion/ffi/src/udaf/groups_accumulator.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::any::Any;
1819
use std::ffi::c_void;
1920
use std::ops::Deref;
2021
use std::ptr::null_mut;
@@ -245,9 +246,12 @@ unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator)
245246

246247
impl From<Box<dyn GroupsAccumulator>> for FFI_GroupsAccumulator {
247248
fn from(accumulator: Box<dyn GroupsAccumulator>) -> Self {
248-
// if let Some(accumulator) = accumulator.into_any().downcast_ref::<ForeignGroupsAccumulator>() {
249-
// return accumulator.accumulator;
250-
// }
249+
if (accumulator.as_ref() as &dyn Any).is::<ForeignGroupsAccumulator>() {
250+
let accumulator = (accumulator as Box<dyn Any>)
251+
.downcast::<ForeignGroupsAccumulator>()
252+
.expect("already checked type");
253+
return accumulator.accumulator;
254+
}
251255

252256
let supports_convert_to_state = accumulator.supports_convert_to_state();
253257
let private_data = GroupsAccumulatorPrivateData { accumulator };

datafusion/ffi/src/udwf/partition_evaluator.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::any::Any;
1819
use std::ffi::c_void;
1920
use std::ops::Range;
2021

@@ -205,9 +206,12 @@ unsafe extern "C" fn release_fn_wrapper(evaluator: &mut FFI_PartitionEvaluator)
205206

206207
impl From<Box<dyn PartitionEvaluator>> for FFI_PartitionEvaluator {
207208
fn from(evaluator: Box<dyn PartitionEvaluator>) -> Self {
208-
// if let Ok(evaluator) = evaluator.into_any().downcast::<ForeignPartitionEvaluator>() {
209-
// return evaluator.evaluator;
210-
// }
209+
if (evaluator.as_ref() as &dyn Any).is::<ForeignPartitionEvaluator>() {
210+
let evaluator = (evaluator as Box<dyn Any>)
211+
.downcast::<ForeignPartitionEvaluator>()
212+
.expect("already checked type");
213+
return evaluator.evaluator;
214+
}
211215

212216
let is_causal = evaluator.is_causal();
213217
let supports_bounded_execution = evaluator.supports_bounded_execution();

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use super::accumulate::NullState;
3737
#[derive(Debug)]
3838
pub struct BooleanGroupsAccumulator<F>
3939
where
40-
F: Fn(bool, bool) -> bool + Send + Sync,
40+
F: Fn(bool, bool) -> bool + Send + Sync + 'static,
4141
{
4242
/// values per group
4343
values: BooleanBufferBuilder,
@@ -55,7 +55,7 @@ where
5555

5656
impl<F> BooleanGroupsAccumulator<F>
5757
where
58-
F: Fn(bool, bool) -> bool + Send + Sync,
58+
F: Fn(bool, bool) -> bool + Send + Sync + 'static,
5959
{
6060
pub fn new(bool_fn: F, identity: bool) -> Self {
6161
Self {
@@ -69,7 +69,7 @@ where
6969

7070
impl<F> GroupsAccumulator for BooleanGroupsAccumulator<F>
7171
where
72-
F: Fn(bool, bool) -> bool + Send + Sync,
72+
F: Fn(bool, bool) -> bool + Send + Sync + 'static,
7373
{
7474
fn update_batch(
7575
&mut self,

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use super::accumulate::NullState;
4141
pub struct PrimitiveGroupsAccumulator<T, F>
4242
where
4343
T: ArrowPrimitiveType + Send,
44-
F: Fn(&mut T::Native, T::Native) + Send + Sync,
44+
F: Fn(&mut T::Native, T::Native) + Send + Sync + 'static,
4545
{
4646
/// values per group, stored as the native type
4747
values: Vec<T::Native>,
@@ -62,7 +62,7 @@ where
6262
impl<T, F> PrimitiveGroupsAccumulator<T, F>
6363
where
6464
T: ArrowPrimitiveType + Send,
65-
F: Fn(&mut T::Native, T::Native) + Send + Sync,
65+
F: Fn(&mut T::Native, T::Native) + Send + Sync + 'static,
6666
{
6767
pub fn new(data_type: &DataType, prim_fn: F) -> Self {
6868
Self {
@@ -84,7 +84,7 @@ where
8484
impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
8585
where
8686
T: ArrowPrimitiveType + Send,
87-
F: Fn(&mut T::Native, T::Native) + Send + Sync,
87+
F: Fn(&mut T::Native, T::Native) + Send + Sync + 'static,
8888
{
8989
fn update_batch(
9090
&mut self,

datafusion/functions-aggregate/src/average.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ impl Accumulator for DurationAvgAccumulator {
754754
struct AvgGroupsAccumulator<T, F>
755755
where
756756
T: ArrowNumericType + Send,
757-
F: Fn(T::Native, u64) -> Result<T::Native> + Send,
757+
F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
758758
{
759759
/// The type of the internal sum
760760
sum_data_type: DataType,
@@ -778,7 +778,7 @@ where
778778
impl<T, F> AvgGroupsAccumulator<T, F>
779779
where
780780
T: ArrowNumericType + Send,
781-
F: Fn(T::Native, u64) -> Result<T::Native> + Send,
781+
F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
782782
{
783783
pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
784784
debug!(
@@ -800,7 +800,7 @@ where
800800
impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
801801
where
802802
T: ArrowNumericType + Send,
803-
F: Fn(T::Native, u64) -> Result<T::Native> + Send,
803+
F: Fn(T::Native, u64) -> Result<T::Native> + Send + 'static,
804804
{
805805
fn update_batch(
806806
&mut self,

datafusion/spark/src/function/aggregate/avg.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ impl Accumulator for AvgAccumulator {
213213
struct AvgGroupsAccumulator<T, F>
214214
where
215215
T: ArrowNumericType + Send,
216-
F: Fn(T::Native, i64) -> Result<T::Native> + Send,
216+
F: Fn(T::Native, i64) -> Result<T::Native> + Send + 'static,
217217
{
218218
/// The type of the returned average
219219
return_data_type: DataType,
@@ -231,7 +231,7 @@ where
231231
impl<T, F> AvgGroupsAccumulator<T, F>
232232
where
233233
T: ArrowNumericType + Send,
234-
F: Fn(T::Native, i64) -> Result<T::Native> + Send,
234+
F: Fn(T::Native, i64) -> Result<T::Native> + Send + 'static,
235235
{
236236
pub fn new(return_data_type: &DataType, avg_fn: F) -> Self {
237237
Self {
@@ -246,7 +246,7 @@ where
246246
impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
247247
where
248248
T: ArrowNumericType + Send,
249-
F: Fn(T::Native, i64) -> Result<T::Native> + Send,
249+
F: Fn(T::Native, i64) -> Result<T::Native> + Send + 'static,
250250
{
251251
fn update_batch(
252252
&mut self,

0 commit comments

Comments
 (0)