// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! See `main.rs` for how to run it.

use std::sync::Arc;

use arrow::array::{RecordBatch, record_batch};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::assert_batches_eq;
use datafusion::common::{not_impl_err, Result};
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::datasource::listing::{
ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl,
};
use datafusion::execution::context::SessionContext;
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::expressions::CastExpr;
use datafusion::prelude::SessionConfig;
use datafusion_physical_expr_adapter::{
DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
};
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, ObjectStoreExt, PutPayload};

// Example showing how to implement custom casting rules to adapt file schemas.
// This example enforces strictly widening casts: if the file type is Int64 and
// the table type is Int32, it errors before reading the data. Without this
// custom cast rule DataFusion would apply the narrowing cast and might only
// error after reading a row that it could not cast.
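// The check runs when DataFusion rewrites expressions against each file's
// physical schema through the `PhysicalExprAdapter` registered on the table
// below.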
pub async fn custom_file_casts() -> Result<()> {
println!("=== Creating example data ===");

    // Create a logical / table schema with an Int32 column (nullable)
let logical_schema =
Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)]));

    // Create some data that can be cast (Int16 -> Int32 is widening) and some
    // that cannot (Int64 -> Int32 is narrowing)
let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
let path = Path::from("good.parquet");
let batch = record_batch!(("id", Int16, [1, 2, 3]))?;
write_data(&store, &path, &batch).await?;
let path = Path::from("bad.parquet");
let batch = record_batch!(("id", Int64, [1, 2, 3]))?;
write_data(&store, &path, &batch).await?;

    // Set up query execution
let mut cfg = SessionConfig::new();
// Turn on filter pushdown so that the PhysicalExprAdapter is used
cfg.options_mut().execution.parquet.pushdown_filters = true;
let ctx = SessionContext::new_with_config(cfg);
ctx.runtime_env()
.register_object_store(ObjectStoreUrl::parse("memory://")?.as_ref(), store);

    // Register our good and bad files via ListingTable
let listing_table_config =
ListingTableConfig::new(ListingTableUrl::parse("memory:///good.parquet")?)
.infer_options(&ctx.state())
.await?
.with_schema(Arc::clone(&logical_schema))
.with_expr_adapter_factory(Arc::new(
CustomCastPhysicalExprAdapterFactory::new(Arc::new(
DefaultPhysicalExprAdapterFactory,
)),
));
    let table = ListingTable::try_new(listing_table_config)?;
ctx.register_table("good_table", Arc::new(table))?;

    let listing_table_config =
ListingTableConfig::new(ListingTableUrl::parse("memory:///bad.parquet")?)
.infer_options(&ctx.state())
.await?
.with_schema(Arc::clone(&logical_schema))
.with_expr_adapter_factory(Arc::new(
CustomCastPhysicalExprAdapterFactory::new(Arc::new(
DefaultPhysicalExprAdapterFactory,
)),
));
    let table = ListingTable::try_new(listing_table_config)?;
ctx.register_table("bad_table", Arc::new(table))?;
println!("\n=== File with narrower schema is cast ===");
let query = "SELECT id FROM good_table WHERE id > 1";
println!("Query: {query}");
let batches = ctx.sql(query).await?.collect().await?;
#[rustfmt::skip]
let expected = [
"+----+",
"| id |",
"+----+",
"| 2 |",
"| 3 |",
"+----+",
];
arrow::util::pretty::print_batches(&batches)?;
assert_batches_eq!(expected, &batches);
println!("\n=== File with wider schema errors ===");
let query = "SELECT id FROM bad_table WHERE id > 1";
println!("Query: {query}");
match ctx.sql(query).await?.collect().await {
Ok(_) => panic!("Expected error for narrowing cast, but query succeeded"),
Err(e) => {
println!("Caught expected error: {e}");
}
}
Ok(())
}
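
// A minimal sketch of how a `main.rs` dispatcher might invoke this example
// (hypothetical; the actual `main.rs` in this examples directory may differ):
//
// #[tokio::main]
// async fn main() -> datafusion::common::Result<()> {
//     custom_file_casts().await
// }

/// Write `batch` to `path` in the given object store as a Parquet file.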
async fn write_data(
store: &dyn ObjectStore,
path: &Path,
batch: &RecordBatch,
) -> Result<()> {
let mut buf = vec![];
let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), None)?;
writer.write(batch)?;
writer.close()?;
let payload = PutPayload::from_bytes(buf.into());
store.put(path, payload).await?;
Ok(())
}

/// Factory for creating custom cast physical expression adapters
#[derive(Debug)]
struct CustomCastPhysicalExprAdapterFactory {
inner: Arc<dyn PhysicalExprAdapterFactory>,
}

impl CustomCastPhysicalExprAdapterFactory {
fn new(inner: Arc<dyn PhysicalExprAdapterFactory>) -> Self {
Self { inner }
}
}

impl PhysicalExprAdapterFactory for CustomCastPhysicalExprAdapterFactory {
fn create(
&self,
logical_file_schema: SchemaRef,
physical_file_schema: SchemaRef,
) -> Result<Arc<dyn PhysicalExprAdapter>> {
let inner = self
.inner
.create(logical_file_schema, Arc::clone(&physical_file_schema))?;
Ok(Arc::new(CustomCastsPhysicalExprAdapter {
physical_file_schema,
inner,
}))
}
}

/// Custom `PhysicalExprAdapter` that wraps the default adapter and rejects
/// narrowing file-schema casts.
#[derive(Debug, Clone)]
struct CustomCastsPhysicalExprAdapter {
physical_file_schema: SchemaRef,
inner: Arc<dyn PhysicalExprAdapter>,
}

impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter {
fn rewrite(&self, mut expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
// First delegate to the inner adapter to handle standard schema adaptation
// and discover any necessary casts.
expr = self.inner.rewrite(expr)?;
// Now apply custom casting rules or swap CastExprs for a custom cast
// kernel / expression. For example, DataFusion Comet has a custom cast
// kernel in its native Spark expression implementation.
expr.transform(|expr| {
if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
let input_data_type =
cast.expr().data_type(&self.physical_file_schema)?;
let output_data_type = cast.target_field().data_type();
if !cast.is_bigger_cast(&input_data_type) {
return not_impl_err!(
"Unsupported CAST from {input_data_type} to {output_data_type}"
);
}
}
Ok(Transformed::no(expr))
})
.data()
}
}
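
// Rather than rejecting an unsupported cast, `rewrite` could substitute a
// custom cast implementation. A commented sketch, where `spark_compatible_cast`
// is a hypothetical stand-in for a real kernel such as DataFusion Comet's:
//
// expr.transform(|expr| {
//     if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
//         let replacement: Arc<dyn PhysicalExpr> = spark_compatible_cast(
//             Arc::clone(cast.expr()),
//             cast.target_field().data_type().clone(),
//         )?;
//         return Ok(Transformed::yes(replacement));
//     }
//     Ok(Transformed::no(expr))
// })
// .data()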