Skip to content

Commit b2ca38a

Browse files
committed
add tests
1 parent 967cf0a commit b2ca38a

1 file changed

Lines changed: 83 additions & 1 deletion

File tree

datafusion/physical-plan/src/sorts/builder.rs

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use arrow::error::ArrowError;
2222
use arrow::record_batch::RecordBatch;
2323
use datafusion_common::{DataFusionError, Result};
2424
use datafusion_execution::memory_pool::MemoryReservation;
25+
use std::panic::{catch_unwind, AssertUnwindSafe};
2526
use std::sync::Arc;
2627

2728
#[derive(Debug, Copy, Clone, Default)]
@@ -117,7 +118,22 @@ impl BatchBuilder {
117118
.iter()
118119
.map(|(_, batch)| batch.column(column_idx).as_ref())
119120
.collect();
120-
Ok(interleave(&arrays, indices)?)
121+
// Arrow's interleave panics on i32 offset overflow with
122+
// `.expect("overflow")`. Catch that panic so the caller
123+
// can retry with fewer rows.
124+
match catch_unwind(AssertUnwindSafe(|| interleave(&arrays, indices))) {
125+
Ok(result) => Ok(result?),
126+
Err(panic_payload) => {
127+
if is_overflow_panic(&panic_payload) {
128+
Err(DataFusionError::ArrowError(
129+
Box::new(ArrowError::OffsetOverflowError(0)),
130+
None,
131+
))
132+
} else {
133+
std::panic::resume_unwind(panic_payload);
134+
}
135+
}
136+
}
121137
})
122138
.collect::<Result<Vec<_>>>()
123139
}
@@ -189,3 +205,69 @@ fn is_offset_overflow(e: &DataFusionError) -> bool {
189205
if matches!(err.as_ref(), ArrowError::OffsetOverflowError(_))
190206
)
191207
}
208+
209+
/// Returns true if a caught panic payload looks like an Arrow offset overflow.
210+
fn is_overflow_panic(payload: &Box<dyn std::any::Any + Send>) -> bool {
211+
if let Some(msg) = payload.downcast_ref::<&str>() {
212+
return msg.contains("overflow");
213+
}
214+
if let Some(msg) = payload.downcast_ref::<String>() {
215+
return msg.contains("overflow");
216+
}
217+
false
218+
}
219+
220+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, UnboundedMemoryPool};

    /// Interleaving string columns whose combined byte length exceeds
    /// `i32::MAX` must not panic. Arrow's `interleave` aborts with
    /// `.expect("overflow")` in that situation; `BatchBuilder` is expected
    /// to catch the panic and retry with fewer rows until the output fits
    /// within i32 offsets.
    #[test]
    fn test_interleave_overflow_is_caught() {
        const STREAMS: usize = 3;

        // One ~768 MB value per stream; three together are ~2.3 GB,
        // which is past the i32 offset limit.
        let payload = "x".repeat(768 * 1024 * 1024);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "s",
            DataType::Utf8,
            false,
        )]));

        let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
        let reservation = MemoryConsumer::new("test").register(&pool);
        let mut builder =
            BatchBuilder::new(Arc::clone(&schema), STREAMS, 16, reservation);

        // Feed every stream a single-row batch holding the large value and
        // mark that row for output.
        for stream in 0..STREAMS {
            let values = StringArray::from(vec![payload.as_str()]);
            let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(values)])
                .unwrap();
            builder.push_batch(stream, batch).unwrap();
            builder.push_row(stream);
        }

        // Emitting all three rows at once would overflow i32 offsets; the
        // builder must instead produce a smaller first batch without panicking.
        let first = builder.build_record_batch().unwrap().unwrap();
        assert!(first.num_rows() > 0);
        assert!(first.num_rows() < STREAMS);

        // The leftover rows come out on the next call, accounting for all 3.
        let second = builder.build_record_batch().unwrap().unwrap();
        assert!(second.num_rows() > 0);
        assert_eq!(first.num_rows() + second.num_rows(), STREAMS);
    }
}

0 commit comments

Comments
 (0)