Skip to content

Commit 61455f6

Browse files
committed
add tests
1 parent 3da2ff8 commit 61455f6

1 file changed

Lines changed: 83 additions & 1 deletion

File tree

datafusion/physical-plan/src/sorts/builder.rs

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use arrow::error::ArrowError;
2222
use arrow::record_batch::RecordBatch;
2323
use datafusion_common::{DataFusionError, Result};
2424
use datafusion_execution::memory_pool::MemoryReservation;
25+
use std::panic::{catch_unwind, AssertUnwindSafe};
2526
use std::sync::Arc;
2627

2728
#[derive(Debug, Copy, Clone, Default)]
@@ -139,7 +140,22 @@ impl BatchBuilder {
139140
.iter()
140141
.map(|(_, batch)| batch.column(column_idx).as_ref())
141142
.collect();
142-
Ok(interleave(&arrays, indices)?)
143+
// Arrow's interleave panics on i32 offset overflow with
144+
// `.expect("overflow")`. Catch that panic so the caller
145+
// can retry with fewer rows.
146+
match catch_unwind(AssertUnwindSafe(|| interleave(&arrays, indices))) {
147+
Ok(result) => Ok(result?),
148+
Err(panic_payload) => {
149+
if is_overflow_panic(&panic_payload) {
150+
Err(DataFusionError::ArrowError(
151+
Box::new(ArrowError::OffsetOverflowError(0)),
152+
None,
153+
))
154+
} else {
155+
std::panic::resume_unwind(panic_payload);
156+
}
157+
}
158+
}
143159
})
144160
.collect::<Result<Vec<_>>>()
145161
}
@@ -257,3 +273,69 @@ fn is_offset_overflow(e: &DataFusionError) -> bool {
257273
)
258274
>>>>>>> 967cf0a65 (Fix sort merge interleave overflow)
259275
}
276+
277+
/// Returns true if a caught panic payload looks like an Arrow offset overflow.
278+
fn is_overflow_panic(payload: &Box<dyn std::any::Any + Send>) -> bool {
279+
if let Some(msg) = payload.downcast_ref::<&str>() {
280+
return msg.contains("overflow");
281+
}
282+
if let Some(msg) = payload.downcast_ref::<String>() {
283+
return msg.contains("overflow");
284+
}
285+
false
286+
}
287+
288+
#[cfg(test)]
289+
mod tests {
290+
use super::*;
291+
use arrow::array::StringArray;
292+
use arrow::datatypes::{DataType, Field, Schema};
293+
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, UnboundedMemoryPool};
294+
295+
/// Test that interleaving string columns whose combined byte length
296+
/// exceeds i32::MAX does not panic. Arrow's `interleave` panics with
297+
/// `.expect("overflow")` in this case; `BatchBuilder` catches the
298+
/// panic and retries with fewer rows until the output fits in i32
299+
/// offsets.
300+
#[test]
301+
fn test_interleave_overflow_is_caught() {
302+
// Each string is ~768 MB. Three rows total → ~2.3 GB > i32::MAX.
303+
let big_str: String = "x".repeat(768 * 1024 * 1024);
304+
305+
let schema = Arc::new(Schema::new(vec![Field::new(
306+
"s",
307+
DataType::Utf8,
308+
false,
309+
)]));
310+
311+
let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
312+
let reservation = MemoryConsumer::new("test").register(&pool);
313+
let mut builder = BatchBuilder::new(
314+
Arc::clone(&schema),
315+
/* stream_count */ 3,
316+
/* batch_size */ 16,
317+
reservation,
318+
);
319+
320+
// Push one batch per stream, each containing one large string.
321+
for stream_idx in 0..3 {
322+
let array = StringArray::from(vec![big_str.as_str()]);
323+
let batch =
324+
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)])
325+
.unwrap();
326+
builder.push_batch(stream_idx, batch).unwrap();
327+
builder.push_row(stream_idx);
328+
}
329+
330+
// 3 rows total; interleaving all 3 would overflow i32 offsets.
331+
// The retry loop should halve until it succeeds.
332+
let batch = builder.build_record_batch().unwrap().unwrap();
333+
assert!(batch.num_rows() > 0);
334+
assert!(batch.num_rows() < 3);
335+
336+
// Drain remaining rows.
337+
let batch2 = builder.build_record_batch().unwrap().unwrap();
338+
assert!(batch2.num_rows() > 0);
339+
assert_eq!(batch.num_rows() + batch2.num_rows(), 3);
340+
}
341+
}

0 commit comments

Comments
 (0)