Skip to content

Commit cd99b1d

Browse files
committed
Enhance ArrowArrayStream handling by adding ARROW_STREAM_NAME and improving memory management in drop_stream function
1 parent 137f8ee commit cd99b1d

1 file changed

Lines changed: 25 additions & 13 deletions

File tree

src/dataframe.rs

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
2323
use arrow::compute::can_cast_types;
2424
use arrow::error::ArrowError;
2525
use arrow::ffi::FFI_ArrowSchema;
26-
use arrow::ffi_stream::FFI_ArrowArrayStream;
26+
use arrow::ffi_stream::{FFI_ArrowArrayStream, ARROW_STREAM_NAME};
2727
use arrow::pyarrow::FromPyArrow;
2828
use datafusion::arrow::datatypes::{Schema, SchemaRef};
2929
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
@@ -39,6 +39,7 @@ use datafusion::prelude::*;
3939
use datafusion_ffi::table_provider::FFI_TableProvider;
4040
use futures::{StreamExt, TryStreamExt};
4141
use pyo3::exceptions::PyValueError;
42+
use pyo3::ffi;
4243
use pyo3::prelude::*;
4344
use pyo3::pybacked::PyBackedStr;
4445
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
@@ -58,6 +59,17 @@ use crate::{
5859
expr::{sort_expr::PySortExpr, PyExpr},
5960
};
6061

62+
unsafe extern "C" fn drop_stream(capsule: *mut ffi::PyObject) {
63+
if capsule.is_null() {
64+
return;
65+
}
66+
let name = CString::new(ARROW_STREAM_NAME).unwrap();
67+
let stream_ptr = ffi::PyCapsule_GetPointer(capsule, name.as_ptr()) as *mut FFI_ArrowArrayStream;
68+
if !stream_ptr.is_null() {
69+
drop(Box::from_raw(stream_ptr));
70+
}
71+
}
72+
6173
// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
6274
// - we have not decided on the table_provider approach yet
6375
// this is an interim implementation
@@ -958,20 +970,20 @@ impl PyDataFrame {
958970
!stream_ptr.is_null(),
959971
"ArrowArrayStream pointer should never be null"
960972
);
961-
let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
962-
unsafe {
963-
PyCapsule::new_bound_with_destructor(
964-
py,
965-
stream_ptr,
966-
Some(stream_capsule_name),
967-
|ptr: *mut FFI_ArrowArrayStream, _| {
968-
if !ptr.is_null() {
969-
unsafe { Box::from_raw(ptr) };
970-
}
971-
},
973+
let stream_name = CString::new(ARROW_STREAM_NAME).unwrap();
974+
let capsule = unsafe {
975+
ffi::PyCapsule_New(
976+
stream_ptr as *mut c_void,
977+
stream_name.as_ptr(),
978+
Some(drop_stream),
972979
)
980+
};
981+
if capsule.is_null() {
982+
unsafe { drop(Box::from_raw(stream_ptr)) };
983+
Err(PyErr::fetch(py).into())
984+
} else {
985+
Ok(unsafe { Bound::from_owned_ptr(py, capsule) })
973986
}
974-
.map_err(PyDataFusionError::from)
975987
}
976988

977989
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {

0 commit comments

Comments
 (0)