Skip to content

Commit 54a5515

Browse files
Scolliqalamb
andauthored
perf(spark): use 256-entry byte-pair table in hex encoding (#21836)
Refs #15986. **Why:** `spark_hex` walked one nibble at a time — two `HEX_CHARS[i]` lookups and two `Vec::push` calls per input byte. The hot loop flattens into one indexed load and one `extend_from_slice` per byte with a precomputed table. **What changed:** added `HEX_LOOKUP_LOWER` / `HEX_LOOKUP_UPPER` as `[[u8; 2]; 256]` const tables built at compile time. Bytes path now does a single lookup + 2-byte extend per input byte. The int64 path consumes two nibbles per iteration via the same table, with a fall-through for the high nibble. Behaviour for `0`, `i64::MAX`, `i64::MIN`, `-1` preserved. **Tests:** extended `test_hex_int64` to cover edge values; new `test_hex_lookup_table_covers_all_bytes` cross-checks every entry against `format!("{:02X/x}")`; new `test_spark_hex_binary_round_trip_all_bytes` feeds all 256 byte values through `spark_hex` and verifies the result. `cargo test -p datafusion-spark --lib hex` → 8 pass. `cargo clippy --all-features --all-targets` clean. `cargo bench --no-run` builds — existing `benches/hex.rs` already covers Int64/Utf8/Utf8View/LargeUtf8/Binary/LargeBinary plus dict paths. **Not in this PR:** the #15947 review also flagged Utf8View output and dictionary-key reuse — those felt worth their own PRs to keep this focused on the per-byte hot path. --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 6a09260 commit 54a5515

1 file changed

Lines changed: 101 additions & 16 deletions

File tree

  • datafusion/spark/src/function/math

datafusion/spark/src/function/math/hex.rs

Lines changed: 101 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,22 +108,50 @@ impl ScalarUDFImpl for SparkHex {
108108
}
109109
}
110110

111-
/// Hex encoding lookup tables for fast byte-to-hex conversion
112-
const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef";
113-
const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF";
111+
/// Hex encoding lookup tables for fast byte-to-hex conversion.
112+
///
113+
/// Each entry maps a full byte to its two-character hex encoding so the
114+
/// hot loop becomes one load + one two-byte extend per input byte instead
115+
/// of two nibble lookups and two pushes.
116+
const HEX_CHARS_UPPER_NIBBLES: &[u8; 16] = b"0123456789ABCDEF";
117+
const HEX_CHARS_LOWER_NIBBLES: &[u8; 16] = b"0123456789abcdef";
118+
119+
const HEX_LOOKUP_UPPER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_UPPER_NIBBLES);
120+
const HEX_LOOKUP_LOWER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_LOWER_NIBBLES);
121+
122+
const fn build_hex_lookup(nibbles: &[u8; 16]) -> [[u8; 2]; 256] {
123+
let mut table = [[0u8; 2]; 256];
124+
let mut i = 0;
125+
while i < 256 {
126+
table[i][0] = nibbles[(i >> 4) & 0xF];
127+
table[i][1] = nibbles[i & 0xF];
128+
i += 1;
129+
}
130+
table
131+
}
114132

115133
#[inline]
116134
fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] {
117135
if num == 0 {
118136
return b"0";
119137
}
120138

139+
// Walk the value two nibbles (one full byte) at a time. The buffer is
140+
// filled from the right so the high-order nibbles end up first; the
141+
// returned slice trims leading zeros automatically.
121142
let mut n = num as u64;
122143
let mut i = 16;
123-
while n != 0 {
144+
while n >= 0x10 {
145+
i -= 2;
146+
let pair = HEX_LOOKUP_UPPER[(n & 0xFF) as usize];
147+
buffer[i] = pair[0];
148+
buffer[i + 1] = pair[1];
149+
n >>= 8;
150+
}
151+
if n > 0 {
152+
// Single remaining high nibble (value 0x1..=0xF).
124153
i -= 1;
125-
buffer[i] = HEX_CHARS_UPPER[(n & 0xF) as usize];
126-
n >>= 4;
154+
buffer[i] = HEX_CHARS_UPPER_NIBBLES[n as usize];
127155
}
128156
&buffer[i..]
129157
}
@@ -140,21 +168,21 @@ where
140168
{
141169
let mut builder = StringBuilder::with_capacity(len, len * 64);
142170
let mut buffer = Vec::with_capacity(64);
143-
let hex_chars = if lowercase {
144-
HEX_CHARS_LOWER
171+
let lookup = if lowercase {
172+
&HEX_LOOKUP_LOWER
145173
} else {
146-
HEX_CHARS_UPPER
174+
&HEX_LOOKUP_UPPER
147175
};
148176

149177
for v in iter {
150178
if let Some(b) = v {
151-
buffer.clear();
152179
let bytes = b.as_ref();
180+
buffer.clear();
181+
buffer.reserve(bytes.len() * 2);
153182
for &byte in bytes {
154-
buffer.push(hex_chars[(byte >> 4) as usize]);
155-
buffer.push(hex_chars[(byte & 0x0f) as usize]);
183+
buffer.extend_from_slice(&lookup[byte as usize]);
156184
}
157-
// SAFETY: buffer contains only ASCII hex digests, which are valid UTF-8
185+
// SAFETY: buffer contains only ASCII hex digits, which are valid UTF-8.
158186
unsafe {
159187
builder.append_value(from_utf8_unchecked(&buffer));
160188
}
@@ -327,7 +355,9 @@ mod test {
327355
use std::str::from_utf8_unchecked;
328356
use std::sync::Arc;
329357

330-
use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray};
358+
use arrow::array::{
359+
BinaryArray, DictionaryArray, Int32Array, Int64Array, StringArray,
360+
};
331361
use arrow::{
332362
array::{
333363
BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder,
@@ -427,19 +457,74 @@ mod test {
427457

428458
#[test]
429459
fn test_hex_int64() {
430-
let test_cases = vec![(1234, "4D2"), (-1, "FFFFFFFFFFFFFFFF")];
460+
let test_cases = vec![
461+
(0_i64, "0"),
462+
(1, "1"),
463+
(15, "F"),
464+
(16, "10"),
465+
(255, "FF"),
466+
(256, "100"),
467+
(1234, "4D2"),
468+
(i64::MAX, "7FFFFFFFFFFFFFFF"),
469+
(i64::MIN, "8000000000000000"),
470+
(-1, "FFFFFFFFFFFFFFFF"),
471+
];
431472

432473
for (num, expected) in test_cases {
433474
let mut cache = [0u8; 16];
434475
let slice = super::hex_int64(num, &mut cache);
435476

436477
unsafe {
437478
let result = from_utf8_unchecked(slice);
438-
assert_eq!(expected, result);
479+
assert_eq!(expected, result, "hex_int64({num}) mismatch");
439480
}
440481
}
441482
}
442483

484+
#[test]
485+
fn test_hex_lookup_table_covers_all_bytes() {
486+
// Cross-check the precomputed table against an independent encoder
487+
// for every possible byte value and both casings.
488+
for byte in 0u8..=255 {
489+
let upper = format!("{byte:02X}");
490+
let lower = format!("{byte:02x}");
491+
let upper_pair = super::HEX_LOOKUP_UPPER[byte as usize];
492+
let lower_pair = super::HEX_LOOKUP_LOWER[byte as usize];
493+
assert_eq!(
494+
upper.as_bytes(),
495+
&upper_pair,
496+
"upper encoding mismatch for byte 0x{byte:02X}"
497+
);
498+
assert_eq!(
499+
lower.as_bytes(),
500+
&lower_pair,
501+
"lower encoding mismatch for byte 0x{byte:02X}"
502+
);
503+
}
504+
}
505+
506+
#[test]
507+
fn test_spark_hex_binary_round_trip_all_bytes() {
508+
// Single-row binary input containing every byte value, encoded in
509+
// a single column. Catches per-byte regressions in the bytes path.
510+
let payload: Vec<u8> = (0u8..=255).collect();
511+
let bin_array = BinaryArray::from(vec![Some(payload.as_slice())]);
512+
513+
let result =
514+
super::spark_hex(&[ColumnarValue::Array(Arc::new(bin_array))]).unwrap();
515+
let array = match result {
516+
ColumnarValue::Array(array) => array,
517+
_ => panic!("Expected array"),
518+
};
519+
let strings = as_string_array(&array);
520+
let mut expected = String::with_capacity(512);
521+
for byte in 0u8..=255 {
522+
use std::fmt::Write;
523+
write!(expected, "{byte:02X}").unwrap();
524+
}
525+
assert_eq!(strings.value(0), expected);
526+
}
527+
443528
#[test]
444529
fn test_spark_hex_int64() {
445530
let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

0 commit comments

Comments
 (0)