Skip to content

Commit e7f7fa9

Browse files
2010YOUY01 and comphead
authored
fix: Validate spill read schema (#21738)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Follow-up to a review comment in #21713 (comment) Not a bug fix, this PR tries to be more defensive and catch potential bugs. Before, when you wrote a spill file from a `SpillManager` and then read it with another `SpillManager` of a different schema, it would succeed. This is not an expected use pattern; without validation, an error would get propagated to the caller and become harder to debug. This PR validates the schema when reading the first batch, and fails fast if the schema does not match. Note it only validates the schema: if two `SpillManager`s with the same schema do the read and write, it is still allowed, but this is not an expected use pattern either. Validating this case would require assigning a `SpillManager` UID and adding it to the Arrow IPC file metadata, which can be tricky, so it is left as a TODO for simplicity. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> UTs ## Are there any user-facing changes? 
<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Oleks V <comphead@users.noreply.github.com>
1 parent e5d9145 commit e7f7fa9

2 files changed

Lines changed: 116 additions & 3 deletions

File tree

datafusion/physical-plan/src/spill/mod.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ use arrow::record_batch::RecordBatch;
4949
use arrow_data::ArrayDataBuilder;
5050

5151
use datafusion_common::config::SpillCompression;
52-
use datafusion_common::{DataFusionError, Result, exec_datafusion_err};
52+
use datafusion_common::{DataFusionError, Result, exec_datafusion_err, exec_err};
5353
use datafusion_common_runtime::SpawnedTask;
5454
use datafusion_execution::RecordBatchStream;
5555
use datafusion_execution::disk_manager::RefCountedTempFile;
@@ -121,6 +121,7 @@ impl SpillReaderStream {
121121
unreachable!()
122122
};
123123

124+
let expected_schema = Arc::clone(&self.schema);
124125
let task = SpawnedTask::spawn_blocking(move || {
125126
let file = BufReader::new(File::open(spill_file.path())?);
126127
// SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications
@@ -130,6 +131,21 @@ impl SpillReaderStream {
130131
StreamReader::try_new(file, None)?.with_skip_validation(true)
131132
};
132133

134+
// Validate the schema read from Arrow IPC file is the same as the
135+
// schema of the current `SpillManager`
136+
let actual_schema = reader.schema();
137+
138+
if actual_schema != expected_schema {
139+
return exec_err!(
140+
"Spill file schema mismatch: expected {}, got {}. \
141+
The caller must use the same SpillManager that created the spill file to read it.",
142+
expected_schema,
143+
actual_schema
144+
);
145+
}
146+
147+
// TODO: Same-schema reads from a different SpillManager still pass today.
148+
// Add a SpillManager UID to IPC metadata and validate it here as well.
133149
let next_batch = reader.next().transpose()?;
134150

135151
Ok((reader, next_batch))

datafusion/physical-plan/src/spill/spill_manager.rs

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ impl SpillManager {
161161
}
162162

163163
/// Reads a spill file as a stream. The file must be created by the current
164-
/// `SpillManager`; otherwise behavior is undefined.
164+
/// `SpillManager`; otherwise an error will be returned.
165165
///
166166
/// Output is produced in FIFO order: the batch appended first is read first.
167167
///
@@ -247,15 +247,112 @@ fn byte_view_data_buffer_size<T: ByteViewType>(array: &GenericByteViewArray<T>)
247247

248248
#[cfg(test)]
249249
mod tests {
250+
use super::SpillManager;
251+
use crate::common::collect;
252+
use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics};
250253
use crate::spill::{get_record_batch_memory_size, spill_manager::GetSlicedSize};
251254
use arrow::datatypes::{DataType, Field, Schema};
252255
use arrow::{
253-
array::{ArrayRef, StringViewArray},
256+
array::{ArrayRef, Int32Array, StringArray, StringViewArray},
254257
record_batch::RecordBatch,
255258
};
256259
use datafusion_common::Result;
260+
use datafusion_execution::runtime_env::RuntimeEnv;
257261
use std::sync::Arc;
258262

263+
fn build_test_spill_manager(
264+
env: Arc<RuntimeEnv>,
265+
schema: Arc<Schema>,
266+
) -> SpillManager {
267+
let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
268+
SpillManager::new(env, metrics, schema)
269+
}
270+
271+
fn build_writer_batch(schema: Arc<Schema>) -> Result<RecordBatch> {
272+
RecordBatch::try_new(
273+
schema,
274+
vec![
275+
Arc::new(Int32Array::from(vec![1, 2, 3])),
276+
Arc::new(StringArray::from(vec!["a", "b", "c"])),
277+
],
278+
)
279+
.map_err(Into::into)
280+
}
281+
282+
#[tokio::test]
283+
async fn test_read_spill_as_stream_from_another_spill_manager_same_schema()
284+
-> Result<()> {
285+
let env = Arc::new(RuntimeEnv::default());
286+
let writer_schema = Arc::new(Schema::new(vec![
287+
Field::new("id", DataType::Int32, false),
288+
Field::new("value", DataType::Utf8, false),
289+
]));
290+
let reader_schema = Arc::new(Schema::new(vec![
291+
Field::new("id", DataType::Int32, false),
292+
Field::new("value", DataType::Utf8, false),
293+
]));
294+
295+
let writer =
296+
build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema));
297+
let reader = build_test_spill_manager(env, Arc::clone(&reader_schema));
298+
let written_batch = build_writer_batch(Arc::clone(&writer_schema))?;
299+
300+
let spill_file = writer
301+
.spill_record_batch_and_finish(
302+
std::slice::from_ref(&written_batch),
303+
"writer",
304+
)?
305+
.unwrap();
306+
307+
// Same-schema reads through a different SpillManager currently pass
308+
// because only schema compatibility is validated. This is not a
309+
// supported usage pattern.
310+
let stream = reader.read_spill_as_stream(spill_file, None)?;
311+
assert_eq!(stream.schema(), reader_schema);
312+
313+
let batches = collect(stream).await?;
314+
assert_eq!(batches, vec![written_batch]);
315+
316+
Ok(())
317+
}
318+
319+
#[tokio::test]
320+
async fn test_read_spill_as_stream_from_another_spill_manager_different_schema()
321+
-> Result<()> {
322+
let env = Arc::new(RuntimeEnv::default());
323+
let writer_schema = Arc::new(Schema::new(vec![
324+
Field::new("id", DataType::Int32, false),
325+
Field::new("value", DataType::Utf8, false),
326+
]));
327+
let reader_schema = Arc::new(Schema::new(vec![
328+
Field::new("other_id", DataType::Int32, true),
329+
Field::new("other_value", DataType::Utf8, true),
330+
]));
331+
332+
let writer =
333+
build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema));
334+
let reader = build_test_spill_manager(env, Arc::clone(&reader_schema));
335+
let written_batch = build_writer_batch(Arc::clone(&writer_schema))?;
336+
337+
let spill_file = writer
338+
.spill_record_batch_and_finish(
339+
std::slice::from_ref(&written_batch),
340+
"writer",
341+
)?
342+
.unwrap();
343+
344+
let stream = reader.read_spill_as_stream(spill_file, None)?;
345+
let err = collect(stream)
346+
.await
347+
.expect_err("schema mismatch should fail fast");
348+
let err = err.to_string();
349+
assert!(err.contains("Spill file schema mismatch"));
350+
assert!(err.contains("expected"));
351+
assert!(err.contains("got"));
352+
353+
Ok(())
354+
}
355+
259356
#[test]
260357
fn check_sliced_size_for_string_view_array() -> Result<()> {
261358
let array_length = 50;

0 commit comments

Comments
 (0)