Skip to content

Commit 3763ad4

Browse files
feat: Additional Canonical Extension Types (#21291)
## Which issue does this PR close? - Related to #21144. ## Rationale for this change Implements `DFExtensionType` for the other canonical extension types (except Parquet types). ## What changes are included in this PR? - Implement `DFExtensionType` for: - `arrow.bool8` (custom formatting with true/false) - `arrow.fixed_shape_tensor` (no custom formatting for now) - `arrow.json` (no custom formatting for now) - `arrow.opaque` (no custom formatting for now) - `arrow.timestamp_with_offset` (custom formatting as timestamp) - `arrow.variable_shape_tensor` (no custom formatting) - Introduce a wrapper struct `DFUuid` for `Uuid` so that it is consistent with the other extension types. I don't know whether we truly need the wrapper structs for extension types that only have a single valid storage type based on their metadata (e.g., Bool8, Uuid). Open to any opinions. @paleolimbot Is this the kind of wrapper struct you imagined? Should we also add end-to-end tests for the other extension types? Currently, we only have the UUID one from the last PR. I think for variant support it's best to wait for general Variant support in DataFusion, as this is currently done in https://github.com/datafusion-contrib/datafusion-variant . ## Are these changes tested? Custom formatters are tested. The rest is mainly boilerplate code. ## Are there any user-facing changes? Yes, more extension types, and `Uuid` is renamed to `DFUuid`. If we plan to keep our own wrapper structs, we should merge that before releasing a version with the old `Uuid` so we have no breaking changes.
1 parent 961c5fc commit 3763ad4

12 files changed

Lines changed: 903 additions & 171 deletions

File tree

datafusion-examples/examples/extension_types/temperature.rs

Lines changed: 64 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use arrow::array::{
2121
};
2222
use arrow::datatypes::{Float32Type, Float64Type};
2323
use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult};
24-
use arrow_schema::extension::ExtensionType;
24+
use arrow_schema::extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY};
2525
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
2626
use datafusion::dataframe::DataFrame;
2727
use datafusion::error::Result;
@@ -30,8 +30,9 @@ use datafusion::prelude::SessionContext;
3030
use datafusion_common::internal_err;
3131
use datafusion_common::types::DFExtensionType;
3232
use datafusion_expr::registry::{
33-
DefaultExtensionTypeRegistration, ExtensionTypeRegistry, MemoryExtensionTypeRegistry,
33+
ExtensionTypeRegistration, ExtensionTypeRegistry, MemoryExtensionTypeRegistry,
3434
};
35+
use std::collections::HashMap;
3536
use std::fmt::{Display, Write};
3637
use std::sync::Arc;
3738

@@ -50,13 +51,15 @@ fn create_session_context() -> Result<SessionContext> {
5051
let registry = MemoryExtensionTypeRegistry::new_empty();
5152

5253
// The registration creates a new instance of the extension type with the deserialized metadata.
53-
let temp_registration =
54-
DefaultExtensionTypeRegistration::new_arc(|storage_type, metadata| {
55-
Ok(TemperatureExtensionType::new(
56-
storage_type.clone(),
57-
metadata,
58-
))
59-
});
54+
let temp_registration = ExtensionTypeRegistration::new_arc(
55+
TemperatureExtensionType::NAME,
56+
|storage_type, metadata| {
57+
Ok(Arc::new(TemperatureExtensionType::try_new(
58+
storage_type,
59+
TemperatureUnit::deserialize(metadata)?,
60+
)?))
61+
},
62+
);
6063
registry.add_extension_type_registration(temp_registration)?;
6164

6265
let state = SessionStateBuilder::default()
@@ -96,26 +99,15 @@ async fn register_temperature_table(ctx: &SessionContext) -> Result<DataFrame> {
9699
fn example_schema() -> SchemaRef {
97100
Arc::new(Schema::new(vec![
98101
Field::new("city", DataType::Utf8, false),
99-
Field::new("celsius", DataType::Float64, false).with_extension_type(
100-
TemperatureExtensionType::new(DataType::Float64, TemperatureUnit::Celsius),
101-
),
102-
Field::new("fahrenheit", DataType::Float64, false).with_extension_type(
103-
TemperatureExtensionType::new(DataType::Float64, TemperatureUnit::Fahrenheit),
104-
),
105-
Field::new("kelvin", DataType::Float32, false).with_extension_type(
106-
TemperatureExtensionType::new(DataType::Float32, TemperatureUnit::Kelvin),
107-
),
102+
Field::new("celsius", DataType::Float64, false)
103+
.with_metadata(create_metadata(TemperatureUnit::Celsius)),
104+
Field::new("fahrenheit", DataType::Float64, false)
105+
.with_metadata(create_metadata(TemperatureUnit::Fahrenheit)),
106+
Field::new("kelvin", DataType::Float32, false)
107+
.with_metadata(create_metadata(TemperatureUnit::Kelvin)),
108108
]))
109109
}
110110

111-
/// Represents the unit of a temperature reading.
112-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
113-
pub enum TemperatureUnit {
114-
Celsius,
115-
Fahrenheit,
116-
Kelvin,
117-
}
118-
119111
/// Represents a float that semantically represents a temperature. The temperature can be one of
120112
/// the supported [`TemperatureUnit`]s.
121113
///
@@ -143,46 +135,57 @@ pub struct TemperatureExtensionType {
143135
}
144136

145137
impl TemperatureExtensionType {
138+
/// The name of the extension type.
139+
pub const NAME: &'static str = "custom.temperature";
140+
146141
/// Creates a new [`TemperatureExtensionType`].
147-
pub fn new(storage_type: DataType, temperature_unit: TemperatureUnit) -> Self {
148-
Self {
149-
storage_type,
150-
temperature_unit,
142+
pub fn try_new(
143+
storage_type: &DataType,
144+
temperature_unit: TemperatureUnit,
145+
) -> Result<Self, ArrowError> {
146+
match storage_type {
147+
DataType::Float32 | DataType::Float64 => {}
148+
_ => {
149+
return Err(ArrowError::InvalidArgumentError(format!(
150+
"Invalid data type: {storage_type} for temperature type, expected Float32 or Float64",
151+
)));
152+
}
151153
}
154+
155+
let result = Self {
156+
storage_type: storage_type.clone(),
157+
temperature_unit,
158+
};
159+
Ok(result)
152160
}
153161
}
154162

155-
/// Implementation of [`ExtensionType`] for [`TemperatureExtensionType`].
156-
///
157-
/// This implements the arrow-rs trait for reading, writing, and validating extension types.
158-
impl ExtensionType for TemperatureExtensionType {
159-
/// Arrow extension type name that is stored in the `ARROW:extension:name` field.
160-
const NAME: &'static str = "custom.temperature";
161-
type Metadata = TemperatureUnit;
162-
163-
fn metadata(&self) -> &Self::Metadata {
164-
&self.temperature_unit
165-
}
163+
/// The unit in which a temperature reading is expressed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TemperatureUnit {
    /// The reading is in degrees Celsius.
    Celsius,
    /// The reading is in degrees Fahrenheit.
    Fahrenheit,
    /// The reading is in Kelvin.
    Kelvin,
}
166170

171+
impl TemperatureUnit {
167172
/// Arrow extension type metadata is encoded as a string and stored using the
168173
/// `ARROW:extension:metadata` key. As we only store the name of the unit, a simple string
169174
/// suffices. Extension types can store more complex metadata using serialization formats like
170175
/// JSON.
171-
fn serialize_metadata(&self) -> Option<String> {
172-
let s = match self.temperature_unit {
176+
pub fn serialize(self) -> String {
177+
let result = match self {
173178
TemperatureUnit::Celsius => "celsius",
174179
TemperatureUnit::Fahrenheit => "fahrenheit",
175180
TemperatureUnit::Kelvin => "kelvin",
176181
};
177-
Some(s.to_string())
182+
result.to_owned()
178183
}
179184

180-
/// Inverse operation of [`Self::serialize_metadata`]. This creates the [`TemperatureUnit`]
185+
/// Inverse operation of [`TemperatureUnit::serialize`]. This creates the [`TemperatureUnit`]
181186
/// value from the serialized string.
182-
fn deserialize_metadata(
183-
metadata: Option<&str>,
184-
) -> std::result::Result<Self::Metadata, ArrowError> {
185-
match metadata {
187+
pub fn deserialize(value: Option<&str>) -> std::result::Result<Self, ArrowError> {
188+
match value {
186189
Some("celsius") => Ok(TemperatureUnit::Celsius),
187190
Some("fahrenheit") => Ok(TemperatureUnit::Fahrenheit),
188191
Some("kelvin") => Ok(TemperatureUnit::Kelvin),
@@ -194,28 +197,18 @@ impl ExtensionType for TemperatureExtensionType {
194197
)),
195198
}
196199
}
200+
}
197201

198-
/// Checks that the extension type supports a given [`DataType`].
199-
fn supports_data_type(
200-
&self,
201-
data_type: &DataType,
202-
) -> std::result::Result<(), ArrowError> {
203-
match data_type {
204-
DataType::Float32 | DataType::Float64 => Ok(()),
205-
_ => Err(ArrowError::InvalidArgumentError(format!(
206-
"Invalid data type: {data_type} for temperature type, expected Float32 or Float64",
207-
))),
208-
}
209-
}
210-
211-
fn try_new(
212-
data_type: &DataType,
213-
metadata: Self::Metadata,
214-
) -> std::result::Result<Self, ArrowError> {
215-
let instance = Self::new(data_type.clone(), metadata);
216-
instance.supports_data_type(data_type)?;
217-
Ok(instance)
218-
}
202+
/// This creates a metadata map for the temperature type. Another way of writing the metadata can be
203+
/// implemented using arrow-rs' [`ExtensionType`](arrow_schema::extension::ExtensionType) trait.
204+
fn create_metadata(unit: TemperatureUnit) -> HashMap<String, String> {
205+
HashMap::from([
206+
(
207+
EXTENSION_TYPE_NAME_KEY.to_owned(),
208+
TemperatureExtensionType::NAME.to_owned(),
209+
),
210+
(EXTENSION_TYPE_METADATA_KEY.to_owned(), unit.serialize()),
211+
])
219212
}
220213

221214
/// Implementation of [`DFExtensionType`] for [`TemperatureExtensionType`].
@@ -227,7 +220,7 @@ impl DFExtensionType for TemperatureExtensionType {
227220
}
228221

229222
fn serialize_metadata(&self) -> Option<String> {
230-
ExtensionType::serialize_metadata(self)
223+
Some(self.temperature_unit.serialize())
231224
}
232225

233226
fn create_array_formatter<'fmt>(
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::Result;
19+
use crate::error::_internal_err;
20+
use crate::types::extension::DFExtensionType;
21+
use arrow::array::{Array, Int8Array};
22+
use arrow::datatypes::DataType;
23+
use arrow::util::display::{ArrayFormatter, DisplayIndex, FormatOptions, FormatResult};
24+
use arrow_schema::extension::{Bool8, ExtensionType};
25+
use std::fmt::Write;
26+
27+
/// Defines the extension type logic for the canonical `arrow.bool8` extension type. This extension
28+
/// type allows storing a Boolean value in a single byte, instead of a single bit.
29+
///
30+
/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also
31+
/// [`Bool8`] for the implementation of arrow-rs, which this type uses internally.
32+
///
33+
/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#bit-boolean>
34+
#[derive(Debug, Clone)]
35+
pub struct DFBool8(Bool8);
36+
37+
impl DFBool8 {
38+
/// Creates a new [`DFBool8`], validating that the storage type is compatible with the
39+
/// extension type.
40+
///
41+
/// Even though [`DFBool8`] only supports a single storage type ([`DataType::Int8`]), passing-in
42+
/// the storage type allows conveniently validating whether this extension type is compatible
43+
/// with a given [`DataType`].
44+
pub fn try_new(
45+
data_type: &DataType,
46+
metadata: <Bool8 as ExtensionType>::Metadata,
47+
) -> Result<Self> {
48+
// Validates the storage type
49+
Ok(Self(<Bool8 as ExtensionType>::try_new(
50+
data_type, metadata,
51+
)?))
52+
}
53+
}
54+
55+
impl DFExtensionType for DFBool8 {
56+
fn storage_type(&self) -> DataType {
57+
DataType::Int8
58+
}
59+
60+
fn serialize_metadata(&self) -> Option<String> {
61+
self.0.serialize_metadata()
62+
}
63+
64+
fn create_array_formatter<'fmt>(
65+
&self,
66+
array: &'fmt dyn Array,
67+
options: &FormatOptions<'fmt>,
68+
) -> Result<Option<ArrayFormatter<'fmt>>> {
69+
if array.data_type() != &DataType::Int8 {
70+
return _internal_err!("Wrong array type for Bool8");
71+
}
72+
73+
let display_index = Bool8ValueDisplayIndex {
74+
array: array.as_any().downcast_ref().unwrap(),
75+
null_str: options.null(),
76+
};
77+
Ok(Some(ArrayFormatter::new(
78+
Box::new(display_index),
79+
options.safe(),
80+
)))
81+
}
82+
}
83+
84+
/// Pretty printer for binary bool8 values.
85+
#[derive(Debug, Clone, Copy)]
86+
struct Bool8ValueDisplayIndex<'a> {
87+
array: &'a Int8Array,
88+
null_str: &'a str,
89+
}
90+
91+
impl DisplayIndex for Bool8ValueDisplayIndex<'_> {
92+
fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult {
93+
if self.array.is_null(idx) {
94+
write!(f, "{}", self.null_str)?;
95+
return Ok(());
96+
}
97+
98+
let bytes = self.array.value(idx);
99+
write!(f, "{}", bytes != 0)?;
100+
Ok(())
101+
}
102+
}
103+
104+
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero prints as `false`, any non-zero value prints as `true`, and null slots print the
    /// configured null string.
    #[test]
    pub fn test_pretty_bool8() {
        let storage = Int8Array::from_iter([Some(0), Some(1), Some(-20), None]);
        let options = FormatOptions::default().with_null("NULL");

        let formatter = DFBool8(Bool8 {})
            .create_array_formatter(&storage, &options)
            .unwrap()
            .unwrap();

        let expected = ["false", "true", "true", "NULL"];
        for (idx, want) in expected.into_iter().enumerate() {
            assert_eq!(formatter.value(idx).to_string(), want);
        }
    }
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::Result;
19+
use crate::types::extension::DFExtensionType;
20+
use arrow::datatypes::DataType;
21+
use arrow_schema::extension::{ExtensionType, FixedShapeTensor};
22+
23+
/// Defines the extension type logic for the canonical `arrow.fixed_shape_tensor` extension type.
24+
/// This extension type can be used to store a [tensor](https://en.wikipedia.org/wiki/Tensor) of
25+
/// a fixed shape.
26+
///
27+
/// See [`DFExtensionType`] for information on DataFusion's extension type mechanism. See also
28+
/// [`FixedShapeTensor`] for the implementation of arrow-rs, which this type uses internally.
29+
///
30+
/// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor>
31+
#[derive(Debug, Clone)]
32+
pub struct DFFixedShapeTensor {
33+
inner: FixedShapeTensor,
34+
/// The storage type of the tensor.
35+
///
36+
/// While we could reconstruct the storage type from the inner [`FixedShapeTensor`], we may
37+
/// choose a different name for the field within the [`DataType::FixedSizeList`] which can
38+
/// cause problems down the line (e.g., checking for equality).
39+
storage_type: DataType,
40+
}
41+
42+
impl DFFixedShapeTensor {
43+
/// Creates a new [`DFFixedShapeTensor`], validating that the storage type is compatible with
44+
/// the extension type.
45+
pub fn try_new(
46+
data_type: &DataType,
47+
metadata: <FixedShapeTensor as ExtensionType>::Metadata,
48+
) -> Result<Self> {
49+
Ok(Self {
50+
inner: <FixedShapeTensor as ExtensionType>::try_new(data_type, metadata)?,
51+
storage_type: data_type.clone(),
52+
})
53+
}
54+
}
55+
56+
impl DFExtensionType for DFFixedShapeTensor {
57+
fn storage_type(&self) -> DataType {
58+
self.storage_type.clone()
59+
}
60+
61+
fn serialize_metadata(&self) -> Option<String> {
62+
self.inner.serialize_metadata()
63+
}
64+
}

0 commit comments

Comments
 (0)