Skip to content

Commit 42e6d88

Browse files
committed
test: add context drop test for SessionContext to prevent segfaults
1 parent e20046a commit 42e6d88

2 files changed

Lines changed: 37 additions & 16 deletions

File tree

python/tests/test_dataframe.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,6 +1832,22 @@ def test_arrow_c_stream_capsule_manual_destructor_noop(ctx):
18321832
gc.collect()
18331833

18341834

1835+
def test_arrow_c_stream_context_drop_no_segfault():
    """Exported Arrow C stream stays usable after SessionContext is dropped.

    Regression test: drop both the DataFrame and its SessionContext while a
    PyCapsule from __arrow_c_stream__ is still alive, then consume the stream.
    Repeats several times so a use-after-free would reliably segfault.
    """
    for _attempt in range(5):
        session = SessionContext()
        frame = session.sql("SELECT 1 AS a")
        stream_capsule = frame.__arrow_c_stream__()

        # Release the Python-side owners before the capsule is consumed; the
        # stream must keep the underlying context alive on its own.
        del frame
        del session

        batch_reader = pa.RecordBatchReader._import_from_c_capsule(stream_capsule)
        del stream_capsule

        result = batch_reader.read_all()
        assert result.num_rows == 1

        del batch_reader
        gc.collect()
1849+
1850+
18351851
def test_arrow_stream_to_pylist(df):
18361852
capsule = df.__arrow_c_stream__()
18371853
reader = pa.RecordBatchReader._import_from_c_capsule(capsule)

src/dataframe.rs

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use std::collections::HashMap;
1919
use std::ffi::CString;
2020
use std::sync::Arc;
21+
use tokio::sync::Mutex;
2122

2223
use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
2324
use arrow::compute::can_cast_types;
@@ -33,8 +34,8 @@ use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, Table
3334
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
3435
use datafusion::datasource::TableProvider;
3536
use datafusion::error::DataFusionError;
36-
use datafusion::execution::SendableRecordBatchStream;
3737
use datafusion::execution::session_state::SessionState;
38+
use datafusion::execution::SendableRecordBatchStream;
3839
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
3940
use datafusion::prelude::*;
4041
use datafusion_ffi::table_provider::FFI_TableProvider;
@@ -360,7 +361,7 @@ impl PyDataFrame {
360361
/// their original partition order. When a `projection` is set, each batch is
361362
/// converted via `record_batch_into_schema` to apply schema changes per batch.
362363
struct PartitionedDataFrameStreamReader {
363-
streams: Vec<SendableRecordBatchStream>,
364+
streams: Vec<Arc<Mutex<SendableRecordBatchStream>>>,
364365
// Hold a reference to the session state to keep the context alive
365366
_state: Arc<SessionState>,
366367
schema: SchemaRef,
@@ -373,12 +374,17 @@ impl Iterator for PartitionedDataFrameStreamReader {
373374

374375
fn next(&mut self) -> Option<Self::Item> {
375376
while self.current < self.streams.len() {
376-
let stream = &mut self.streams[self.current];
377-
let fut = poll_next_batch(stream);
378-
let result = Python::with_gil(|py| wait_for_future(py, fut));
377+
let stream = self.streams[self.current].clone();
378+
379+
let result = Python::with_gil(|py| {
380+
spawn_future(py, async move {
381+
let mut s = stream.lock().await;
382+
poll_next_batch(&mut s).await
383+
})
384+
});
379385

380386
match result {
381-
Ok(Ok(Some(batch))) => {
387+
Ok(Some(batch)) => {
382388
let batch = if let Some(ref schema) = self.projection {
383389
match record_batch_into_schema(batch, schema.as_ref()) {
384390
Ok(b) => b,
@@ -389,13 +395,10 @@ impl Iterator for PartitionedDataFrameStreamReader {
389395
};
390396
return Some(Ok(batch));
391397
}
392-
Ok(Ok(None)) => {
398+
Ok(None) => {
393399
self.current += 1;
394400
continue;
395401
}
396-
Ok(Err(e)) => {
397-
return Some(Err(ArrowError::ExternalError(Box::new(e))));
398-
}
399402
Err(e) => {
400403
return Some(Err(ArrowError::ExternalError(Box::new(e))));
401404
}
@@ -956,6 +959,10 @@ impl PyDataFrame {
956959
let df = self.df.as_ref().clone();
957960
let state = df.session_state().clone();
958961
let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
962+
let streams = streams
963+
.into_iter()
964+
.map(|s| Arc::new(Mutex::new(s)))
965+
.collect();
959966

960967
let mut schema: Schema = self.df.schema().to_owned().into();
961968
let mut projection: Option<SchemaRef> = None;
@@ -999,12 +1006,10 @@ impl PyDataFrame {
9991006
let df = self.df.as_ref().clone();
10001007
let state = df.session_state().clone();
10011008
let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
1002-
Ok(
1003-
streams
1004-
.into_iter()
1005-
.map(|stream| PyRecordBatchStream::new(stream, state.clone()))
1006-
.collect(),
1007-
)
1009+
Ok(streams
1010+
.into_iter()
1011+
.map(|stream| PyRecordBatchStream::new(stream, state.clone()))
1012+
.collect())
10081013
}
10091014

10101015
/// Convert to pandas dataframe with pyarrow

0 commit comments

Comments
 (0)