diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs
index ab9249985b863..2ceec403be8a9 100644
--- a/datafusion/physical-plan/src/topk/mod.rs
+++ b/datafusion/physical-plan/src/topk/mod.rs
@@ -17,12 +17,16 @@
 //! TopK: Combination of Sort / LIMIT
 
+mod native;
+
 use arrow::{
     array::{Array, AsArray},
     compute::{FilterBuilder, interleave_record_batch, prep_null_mask_filter},
     row::{RowConverter, Rows, SortField},
 };
+use arrow_schema::SortOptions;
 use datafusion_expr::{ColumnarValue, Operator};
+use native::{NativeTopKHeap, find_new_native_topk_items, supports_native_topk};
 use std::mem::size_of;
 use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
 
@@ -114,12 +118,8 @@ pub struct TopK {
     batch_size: usize,
     /// sort expressions
     expr: LexOrdering,
-    /// row converter, for sort keys
-    row_converter: RowConverter,
-    /// scratch space for converting rows
-    scratch_rows: Rows,
-    /// stores the top k values and their sort key values, in order
-    heap: TopKHeap,
+    /// Heap variant: either row-based (general) or native (single primitive column)
+    inner: TopKInner,
     /// row converter, for common keys between the sort keys and the input ordering
     common_sort_prefix_converter: Option<RowConverter>,
     /// Common sort prefix between the input and the sort expressions to allow early exit optimization
@@ -132,6 +132,23 @@ pub struct TopK {
     pub(crate) finished: bool,
 }
 
+/// Heap strategy: general row-based or specialized native encoding.
+enum TopKInner {
+    /// General purpose: uses Arrow RowConverter for multi-column or
+    /// non-primitive sort keys.
+    Row {
+        row_converter: RowConverter,
+        scratch_rows: Rows,
+        heap: TopKHeap,
+    },
+    /// Optimized path for single primitive column sorts: encodes sort
+    /// keys as inline `u128` values, avoiding RowConverter overhead.
+    Native {
+        sort_options: SortOptions,
+        heap: NativeTopKHeap,
+    },
+}
+
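To make the trade-off concrete, here is an illustrative sketch (not part of the patch) of the two key representations the `TopKInner` variants hold: the row path materializes order-preserving byte rows via `RowConverter`, while the native path packs each key into an inline `u128` using the encoding defined in `native.rs` below.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::DataType;
use arrow::row::{RowConverter, SortField};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let keys = Int32Array::from(vec![3, -1, 2]);

    // Row path: each sort key becomes a heap-allocated, order-preserving byte row.
    let converter = RowConverter::new(vec![SortField::new(DataType::Int32)])?;
    let rows = converter.convert_columns(&[Arc::new(keys) as ArrayRef])?;
    assert!(rows.row(1) < rows.row(2) && rows.row(2) < rows.row(0)); // -1 < 2 < 3

    // Native path: each key becomes an inline u128 (no per-row allocation),
    // mirroring encode_signed/encode_key from native.rs.
    let encode = |v: i32| ((v as i64 as u64) ^ (1u64 << 63)) as u128 + 1;
    assert!(encode(-1) < encode(2) && encode(2) < encode(3));
    Ok(())
}
```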
 /// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]
 ///
 /// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters
@@ -198,13 +215,26 @@ impl TopK {
         let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]"))
             .register(&runtime.memory_pool);
 
-        let sort_fields = build_sort_fields(&expr, &schema)?;
+        // Use the native path for single primitive column sorts
+        let use_native =
+            expr.len() == 1 && supports_native_topk(&expr[0].expr.data_type(&schema)?);
 
-        // TODO there is potential to add special cases for single column sort fields
-        // to improve performance
-        let row_converter = RowConverter::new(sort_fields)?;
-        let scratch_rows =
-            row_converter.empty_rows(batch_size, ESTIMATED_BYTES_PER_ROW * batch_size);
+        let inner = if use_native {
+            TopKInner::Native {
+                sort_options: expr[0].options,
+                heap: NativeTopKHeap::new(k, batch_size),
+            }
+        } else {
+            let sort_fields = build_sort_fields(&expr, &schema)?;
+            let row_converter = RowConverter::new(sort_fields)?;
+            let scratch_rows = row_converter
+                .empty_rows(batch_size, ESTIMATED_BYTES_PER_ROW * batch_size);
+            TopKInner::Row {
+                row_converter,
+                scratch_rows,
+                heap: TopKHeap::new(k, batch_size),
+            }
+        };
 
         let prefix_row_converter = if common_sort_prefix.is_empty() {
             None
@@ -219,9 +249,7 @@ impl TopK {
             reservation,
             batch_size,
             expr,
-            row_converter,
-            scratch_rows,
-            heap: TopKHeap::new(k, batch_size),
+            inner,
             common_sort_prefix_converter: prefix_row_converter,
             common_sort_prefix: Arc::from(common_sort_prefix),
             finished: false,
@@ -281,28 +309,69 @@ impl TopK {
                 .map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
                 .collect::<Result<Vec<_>>>()?;
         }
-        // reuse existing `Rows` to avoid reallocations
-        let rows = &mut self.scratch_rows;
-        rows.clear();
-        self.row_converter.append(rows, &sort_keys)?;
-
-        let mut batch_entry = self.heap.register_batch(batch.clone());
-        let replacements = match selected_rows {
-            Some(filter) => {
-                self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry)
-            }
-            None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry),
-        };
+        let replacements = match &mut self.inner {
+            TopKInner::Row {
+                row_converter,
+                scratch_rows,
+                heap,
+            } => {
+                // reuse existing `Rows` to avoid reallocations
+                scratch_rows.clear();
+                row_converter.append(scratch_rows, &sort_keys)?;
+
+                let mut batch_entry = heap.register_batch(batch.clone());
+                let replacements = match selected_rows {
+                    Some(ref filter) => find_new_topk_items_row(
+                        heap,
+                        scratch_rows,
+                        filter.values().set_indices(),
+                        &mut batch_entry,
+                    ),
+                    None => find_new_topk_items_row(
+                        heap,
+                        scratch_rows,
+                        0..sort_keys[0].len(),
+                        &mut batch_entry,
+                    ),
+                };
+                if replacements > 0 {
+                    heap.insert_batch_entry(batch_entry);
+                    heap.maybe_compact()?;
+                }
+                replacements
+            }
+            TopKInner::Native { sort_options, heap } => {
+                let sort_key = &sort_keys[0];
+                let options = *sort_options;
+                let mut batch_entry = heap.register_batch(batch.clone());
+                let replacements = match selected_rows {
+                    Some(ref filter) => find_new_native_topk_items(
+                        heap,
+                        sort_key,
+                        options,
+                        filter.values().set_indices(),
+                        &mut batch_entry,
+                    ),
+                    None => find_new_native_topk_items(
+                        heap,
+                        sort_key,
+                        options,
+                        0..sort_key.len(),
+                        &mut batch_entry,
+                    ),
+                };
+                if replacements > 0 {
+                    heap.insert_batch_entry(batch_entry);
+                    heap.maybe_compact()?;
+                }
+                replacements
+            }
+        };
 
         if replacements > 0 {
             self.metrics.row_replacements.add(replacements);
-            self.heap.insert_batch_entry(batch_entry);
-
-            // conserve memory
-            self.heap.maybe_compact()?;
-
             // update memory reservation
             self.reservation.try_resize(self.size())?;
@@ -318,26 +387,20 @@ impl TopK {
         Ok(())
     }
 
-    fn find_new_topk_items(
-        &mut self,
-        items: impl Iterator<Item = usize>,
-        batch_entry: &mut RecordBatchEntry,
-    ) -> usize {
-        let mut replacements = 0;
-        let rows = &mut self.scratch_rows;
-        for (index, row) in items.zip(rows.iter()) {
-            match self.heap.max() {
-                // heap has k items, and the new row is greater than the
-                // current max in the heap ==> it is not a new topk
-                Some(max_row) if row.as_ref() >= max_row.row() => {}
-                // don't yet have k items or new item is lower than the currently k low values
-                None | Some(_) => {
-                    self.heap.add(batch_entry, row, index);
-                    replacements += 1;
-                }
-            }
+    /// Helper: access the max row's (batch_id, index) from either heap variant.
+    fn heap_max_ids(&self) -> Option<(u32, usize)> {
+        match &self.inner {
+            TopKInner::Row { heap, .. } => heap.max().map(|r| (r.batch_id, r.index)),
+            TopKInner::Native { heap, .. } => heap.max().map(|r| (r.batch_id, r.index)),
+        }
+    }
+
+    /// Helper: get the store from either heap variant.
+    fn store(&self) -> &RecordBatchStore {
+        match &self.inner {
+            TopKInner::Row { heap, .. } => &heap.store,
+            TopKInner::Native { heap, .. } => &heap.store,
         }
-        replacements
     }
 
     /// Update the filter representation of our TopK heap.
@@ -352,11 +415,19 @@ impl TopK {
     /// ```
     fn update_filter(&mut self) -> Result<()> {
         // If the heap doesn't have k elements yet, we can't create thresholds
-        let Some(max_row) = self.heap.max() else {
+        let Some((max_batch_id, max_index)) = self.heap_max_ids() else {
             return Ok(());
         };
 
-        let new_threshold_row = &max_row.row;
+        // Build a comparable threshold representation for cross-partition dedup.
+        // Row variant: use the row bytes directly.
+        // Native variant: use the u128 key as big-endian bytes.
+        let new_threshold_bytes: Vec<u8> = match &self.inner {
+            TopKInner::Row { heap, .. } => heap.max().unwrap().row.clone(),
+            TopKInner::Native { heap, .. } => {
+                heap.max().unwrap().key.to_be_bytes().to_vec()
+            }
+        };
 
         // Fast path: check if the current value in topk is better than what is
         // currently set in the filter with a read only lock
@@ -367,7 +438,7 @@ impl TopK {
             .as_ref()
             .map(|current_row| {
                 // new < current means new threshold is more selective
-                new_threshold_row < current_row
+                new_threshold_bytes.as_slice() < current_row.as_slice()
             })
             .unwrap_or(true); // No current threshold, so we need to set one
@@ -377,14 +448,18 @@ impl TopK {
         }
 
         // Extract scalar values BEFORE acquiring lock to reduce critical section
-        let thresholds = match self.heap.get_threshold_values(&self.expr)? {
+        let thresholds = match get_threshold_values(
+            self.store(),
+            max_batch_id,
+            max_index,
+            &self.expr,
+        )? {
             Some(t) => t,
             None => return Ok(()),
         };
 
         // Build the filter expression OUTSIDE any synchronization
         let predicate = Self::build_filter_expression(&self.expr, &thresholds)?;
-        let new_threshold = new_threshold_row.to_vec();
 
         // update the threshold. Since there was a lock gap, we must check that it is still
        // the best; it may have changed while we were building the expression without the lock
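The native threshold is exported as `key.to_be_bytes()` so the existing `Vec<u8>` threshold comparisons keep working. A quick illustrative check (not part of the patch) of why big-endian is required here:

```rust
fn main() {
    let (a, b) = (1u128, 256u128);
    assert!(a < b);
    // Big-endian puts the most significant byte first, so lexicographic
    // comparison of the byte slices agrees with numeric order...
    assert!(a.to_be_bytes().as_slice() < b.to_be_bytes().as_slice());
    // ...while little-endian does not (256 would sort before 1).
    assert!(a.to_le_bytes().as_slice() > b.to_le_bytes().as_slice());
}
```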
@@ -396,8 +471,8 @@ impl TopK {
             match old_threshold {
                 Some(old_threshold) => {
                     // new threshold is still better than the old one
-                    if new_threshold.as_slice() < old_threshold.as_slice() {
-                        filter.threshold_row = Some(new_threshold);
+                    if new_threshold_bytes.as_slice() < old_threshold.as_slice() {
+                        filter.threshold_row = Some(new_threshold_bytes);
                     } else {
                         // some other thread updated the threshold to a better
                         // one while we were building so there is no need to
@@ -408,7 +483,7 @@ impl TopK {
                 }
                 None => {
                     // No previous threshold, so we can set the new one
-                    filter.threshold_row = Some(new_threshold);
+                    filter.threshold_row = Some(new_threshold_bytes);
                 }
             };
 
@@ -524,8 +599,8 @@ impl TopK {
             return Ok(());
         };
 
-        // Early exit if the heap is not full (`heap.max()` only returns `Some` if the heap is full).
-        let Some(max_topk_row) = self.heap.max() else {
+        // Early exit if the heap is not full (`heap_max_ids()` only returns `Some` if the heap is full).
+        let Some((max_batch_id, max_index)) = self.heap_max_ids() else {
             return Ok(());
         };
@@ -538,18 +613,13 @@ impl TopK {
 
         // Retrieve the max row from the heap.
         let store_entry = self
-            .heap
-            .store
-            .get(max_topk_row.batch_id)
+            .store()
+            .get(max_batch_id)
             .ok_or(internal_datafusion_err!("Invalid batch id in topK heap"))?;
         let max_batch = &store_entry.batch;
         let mut heap_prefix_scratch =
             prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); // 1 row with capacity ESTIMATED_BYTES_PER_ROW
-        self.compute_common_sort_prefix(
-            max_batch,
-            max_topk_row.index,
-            &mut heap_prefix_scratch,
-        )?;
+        self.compute_common_sort_prefix(max_batch, max_index, &mut heap_prefix_scratch)?;
 
         // If the last row's prefix is strictly greater than the max prefix, mark as finished.
         if batch_prefix_scratch.row(0).as_ref() > heap_prefix_scratch.row(0).as_ref() {
@@ -591,9 +661,7 @@ impl TopK {
             reservation: _,
             batch_size,
             expr: _,
-            row_converter: _,
-            scratch_rows: _,
-            mut heap,
+            inner,
             common_sort_prefix_converter: _,
             common_sort_prefix: _,
             finished: _,
@@ -604,9 +672,14 @@ impl TopK {
         // Mark the dynamic filter as complete now that TopK processing is finished.
         filter.read().expr().mark_complete();
 
+        let emitted = match inner {
+            TopKInner::Row { mut heap, .. } => heap.emit()?,
+            TopKInner::Native { mut heap, .. } => heap.emit()?,
+        };
+
         // break into record batches as needed
         let mut batches = vec![];
-        if let Some(mut batch) = heap.emit()? {
+        if let Some(mut batch) = emitted {
             (&batch).record_output(&metrics.baseline);
 
             loop {
@@ -629,9 +702,14 @@ impl TopK {
     /// return the size of memory used by this operator, in bytes
     fn size(&self) -> usize {
         size_of::<Self>()
-            + self.row_converter.size()
-            + self.scratch_rows.size()
-            + self.heap.size()
+            + match &self.inner {
+                TopKInner::Row {
+                    row_converter,
+                    scratch_rows,
+                    heap,
+                } => row_converter.size() + scratch_rows.size() + heap.size(),
+                TopKInner::Native { heap, .. } => heap.size(),
+            }
     }
 }
 
@@ -844,47 +922,64 @@ impl TopKHeap {
             + self.store.size()
             + self.owned_bytes
     }
+}
 
-    fn get_threshold_values(
-        &self,
-        sort_exprs: &[PhysicalSortExpr],
-    ) -> Result<Option<Vec<ScalarValue>>> {
-        // If the heap doesn't have k elements yet, we can't create thresholds
-        let max_row = match self.max() {
-            Some(row) => row,
-            None => return Ok(None),
-        };
+/// Row-based insertion loop (general multi-column path).
+fn find_new_topk_items_row(
+    heap: &mut TopKHeap,
+    rows: &Rows,
+    items: impl Iterator<Item = usize>,
+    batch_entry: &mut RecordBatchEntry,
+) -> usize {
+    let mut replacements = 0;
+    for (index, row) in items.zip(rows.iter()) {
+        match heap.max() {
+            // heap has k items, and the new row is greater than the
+            // current max in the heap ==> it is not a new topk
+            Some(max_row) if row.as_ref() >= max_row.row() => {}
+            // don't yet have k items or new item is lower than the currently k low values
+            None | Some(_) => {
+                heap.add(batch_entry, row, index);
+                replacements += 1;
+            }
+        }
+    }
+    replacements
+}
 
-        // Get the batch that contains the max row
-        let batch_entry = match self.store.get(max_row.batch_id) {
-            Some(entry) => entry,
-            None => return internal_err!("Invalid batch ID in TopKRow"),
+/// Extract scalar threshold values from the heap's max row.
+/// Works for both row-based and native heaps since it operates
+/// on the underlying RecordBatch store.
+fn get_threshold_values(
+    store: &RecordBatchStore,
+    batch_id: u32,
+    index: usize,
+    sort_exprs: &[PhysicalSortExpr],
+) -> Result<Option<Vec<ScalarValue>>> {
+    let batch_entry = match store.get(batch_id) {
+        Some(entry) => entry,
+        None => return internal_err!("Invalid batch ID in TopK heap"),
+    };
+
+    let mut scalar_values = Vec::with_capacity(sort_exprs.len());
+    for sort_expr in sort_exprs {
+        let expr = Arc::clone(&sort_expr.expr);
+        let value = expr.evaluate(&batch_entry.batch.slice(index, 1))?;
+
+        let scalar = match value {
+            ColumnarValue::Scalar(scalar) => scalar,
+            ColumnarValue::Array(array) if array.len() == 1 => {
+                ScalarValue::try_from_array(&array, 0)?
+            }
+            array => {
+                return internal_err!("Expected a scalar value, got {:?}", array);
+            }
         };
-        // Extract threshold values for each sort expression
-        let mut scalar_values = Vec::with_capacity(sort_exprs.len());
-        for sort_expr in sort_exprs {
-            // Extract the value for this column from the max row
-            let expr = Arc::clone(&sort_expr.expr);
-            let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?;
-
-            // Convert to scalar value - should be a single value since we're evaluating on a single row batch
-            let scalar = match value {
-                ColumnarValue::Scalar(scalar) => scalar,
-                ColumnarValue::Array(array) if array.len() == 1 => {
-                    // Extract the first (and only) value from the array
-                    ScalarValue::try_from_array(&array, 0)?
-                }
-                array => {
-                    return internal_err!("Expected a scalar value, got {:?}", array);
-                }
-            };
-
-            scalar_values.push(scalar);
-        }
-
-        Ok(Some(scalar_values))
+        scalar_values.push(scalar);
     }
+
+    Ok(Some(scalar_values))
 }
 
 /// Represents one of the top K rows held in this heap. Orders
@@ -951,11 +1046,11 @@ impl Ord for TopKRow {
 }
 
 #[derive(Debug)]
-struct RecordBatchEntry {
-    id: u32,
-    batch: RecordBatch,
+pub(crate) struct RecordBatchEntry {
+    pub id: u32,
+    pub batch: RecordBatch,
     // for this batch, how many times has it been used
-    uses: usize,
+    pub uses: usize,
 }
 
 /// This structure tracks [`RecordBatch`] by an id so that:
 ///
 /// 1. The batches can be tracked via an id that can be copied cheaply
 /// 2. The total memory held by all batches is tracked
 #[derive(Debug)]
-struct RecordBatchStore {
+pub(crate) struct RecordBatchStore {
     /// id generator
     next_id: u32,
     /// storage
-    batches: HashMap<u32, RecordBatchEntry>,
+    pub batches: HashMap<u32, RecordBatchEntry>,
     /// total size of all record batches tracked by this store
     batches_size: usize,
 }
 
 impl RecordBatchStore {
-    fn new() -> Self {
+    pub(crate) fn new() -> Self {
         Self {
             next_id: 0,
             batches: HashMap::new(),
@@ -1000,23 +1095,23 @@ impl RecordBatchStore {
     }
 
     /// Clear all values in this store, invalidating all previous batch ids
-    fn clear(&mut self) {
+    pub(crate) fn clear(&mut self) {
         self.batches.clear();
         self.batches_size = 0;
     }
 
-    fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
+    pub(crate) fn get(&self, id: u32) -> Option<&RecordBatchEntry> {
         self.batches.get(&id)
     }
 
     /// returns the total number of batches stored in this store
-    fn len(&self) -> usize {
+    pub(crate) fn len(&self) -> usize {
         self.batches.len()
     }
 
     /// Returns the total number of rows in batches minus the number
     /// which are in use
-    fn unused_rows(&self) -> usize {
+    pub(crate) fn unused_rows(&self) -> usize {
         self.batches
             .values()
             .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses)
@@ -1024,7 +1119,7 @@ impl RecordBatchStore {
     }
 
     /// returns true if the store has nothing stored
-    fn is_empty(&self) -> bool {
+    pub(crate) fn is_empty(&self) -> bool {
         self.batches.is_empty()
     }
 
diff --git a/datafusion/physical-plan/src/topk/native.rs b/datafusion/physical-plan/src/topk/native.rs
new file mode 100644
index 0000000000000..f4acb1d47cbac
--- /dev/null
+++ b/datafusion/physical-plan/src/topk/native.rs
@@ -0,0 +1,489 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Native (non-RowConverter) TopK heap for single primitive column sorts.
+//!
+//! For ORDER BY on a single primitive column, this avoids the overhead of
+//! Arrow's RowConverter by encoding sort keys as inline `u128` values with
+//! order-preserving encoding that handles ASC/DESC and NULLS FIRST/LAST.
+
+use std::cmp::Ordering;
+use std::collections::BinaryHeap;
+use std::mem::size_of;
+
+use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, RecordBatch};
+use arrow::compute::interleave_record_batch;
+use arrow::datatypes::*;
+use arrow_schema::SortOptions;
+use datafusion_common::{HashMap, Result};
+
+use super::{RecordBatchEntry, RecordBatchStore};
+
+// ---------------------------------------------------------------------------
+// Order-preserving encoding into u128
+// ---------------------------------------------------------------------------
+
+/// Encode a signed integer value into u64 preserving ascending order.
+#[inline]
+fn encode_signed(v: i64) -> u64 {
+    (v as u64) ^ (1u64 << 63)
+}
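The XOR with the sign bit is the usual order-preserving trick for two's-complement integers: it biases values so the most negative input maps to 0 and the most positive to `u64::MAX`. A small sketch (not part of the patch) of the mapping at the boundaries:

```rust
// Sketch: flipping the sign bit turns two's-complement order into unsigned order.
fn encode_signed(v: i64) -> u64 {
    (v as u64) ^ (1u64 << 63)
}

fn main() {
    assert_eq!(encode_signed(i64::MIN), 0); // most negative -> smallest
    assert_eq!(encode_signed(-1), (1u64 << 63) - 1);
    assert_eq!(encode_signed(0), 1u64 << 63); // zero lands exactly at the midpoint
    assert_eq!(encode_signed(i64::MAX), u64::MAX); // most positive -> largest
}
```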
+
+/// Encode an f32 value into u64 preserving total ascending order.
+#[inline]
+fn encode_f32(v: f32) -> u64 {
+    // f32 → f64 is lossless, reuse f64 encoding
+    encode_f64(v as f64)
+}
+
+/// Encode an f64 value into u64 preserving total ascending order
+/// (including NaN ordering consistent with `total_cmp`).
+#[inline]
+fn encode_f64(v: f64) -> u64 {
+    let bits = v.to_bits();
+    if bits >> 63 == 1 {
+        // Negative: flip all bits (maps most-negative → 0)
+        !bits
+    } else {
+        // Non-negative: flip sign bit (maps 0.0 → 2^63)
+        bits ^ (1u64 << 63)
+    }
+}
+
+/// Wrap an encoded u64 value with null handling and sort options into a
+/// comparable u128 key.
+///
+/// Layout:
+/// - `NULLS FIRST`: null → `0`, non-null → `encoded + 1`
+/// - `NULLS LAST`: non-null → `encoded + 1`, null → `u128::MAX`
+/// - `DESC`: non-null value bits are flipped before offsetting
+#[inline]
+fn encode_key(is_null: bool, encoded_value: u64, options: SortOptions) -> u128 {
+    if is_null {
+        return if options.nulls_first {
+            0u128
+        } else {
+            u128::MAX
+        };
+    }
+    let v = if options.descending {
+        !encoded_value as u128
+    } else {
+        encoded_value as u128
+    };
+    // +1 ensures non-null values are always > 0 (the NULLS FIRST sentinel)
+    // and < u128::MAX (the NULLS LAST sentinel).
+    v + 1
+}
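Here is an illustrative check (not part of the patch) of the `DESC` branch of `encode_key`: flipping the value bits reverses the order, and the `+1` offset keeps every non-null key strictly between the two null sentinels:

```rust
// Sketch mirroring encode_key's DESC case, with SortOptions flattened to bools.
fn encode_key(is_null: bool, encoded: u64, descending: bool, nulls_first: bool) -> u128 {
    if is_null {
        return if nulls_first { 0 } else { u128::MAX };
    }
    let v = if descending { !encoded as u128 } else { encoded as u128 };
    v + 1
}

fn main() {
    // DESC, NULLS LAST: larger encoded values must produce smaller keys.
    let k10 = encode_key(false, 10, true, false);
    let k20 = encode_key(false, 20, true, false);
    assert!(k20 < k10);
    // Non-null keys stay strictly inside (0, u128::MAX), so the null
    // sentinels always sort first/last regardless of the data.
    assert!(k10 > 0 && k10 < u128::MAX);
    assert_eq!(encode_key(true, 0, true, false), u128::MAX);
}
```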
+
+/// Returns `true` if `dt` can be encoded into a u128 native TopK key.
+pub fn supports_native_topk(dt: &DataType) -> bool {
+    matches!(
+        dt,
+        DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Float32
+            | DataType::Float64
+            | DataType::Date32
+            | DataType::Date64
+            | DataType::Time32(_)
+            | DataType::Time64(_)
+            | DataType::Timestamp(_, _)
+            | DataType::Duration(_)
+    )
+}
+
+// ---------------------------------------------------------------------------
+// NativeTopKRow
+// ---------------------------------------------------------------------------
+
+/// A TopK row with an inline sort key for single-primitive-column sorts.
+///
+/// Unlike [`super::TopKRow`] which heap-allocates a `Vec<u8>` of
+/// RowConverter-encoded bytes, this stores the key as an inline `u128`.
+#[derive(Debug, PartialEq, Eq)]
+pub struct NativeTopKRow {
+    /// Order-preserving encoded sort key.
+    pub key: u128,
+    /// The [`RecordBatch`] this row came from (id into [`RecordBatchStore`]).
+    pub batch_id: u32,
+    /// Row index inside that batch.
+    pub index: usize,
+}
+
+impl NativeTopKRow {
+    fn new(key: u128, batch_id: u32, index: usize) -> Self {
+        Self {
+            key,
+            batch_id,
+            index,
+        }
+    }
+}
+
+impl Ord for NativeTopKRow {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.key.cmp(&other.key)
+    }
+}
+
+impl PartialOrd for NativeTopKRow {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+// ---------------------------------------------------------------------------
+// NativeTopKHeap
+// ---------------------------------------------------------------------------
+
+/// Keeps the smallest K encoded keys using a max-heap ([`BinaryHeap`]):
+/// the *largest* key currently held sits at the top and serves as the
+/// eviction threshold.
+pub struct NativeTopKHeap {
+    k: usize,
+    batch_size: usize,
+    inner: BinaryHeap<NativeTopKRow>,
+    pub store: RecordBatchStore,
+}
+
+impl NativeTopKHeap {
+    pub fn new(k: usize, batch_size: usize) -> Self {
+        assert!(k > 0);
+        Self {
+            k,
+            batch_size,
+            inner: BinaryHeap::new(),
+            store: RecordBatchStore::new(),
+        }
+    }
+
+    pub fn register_batch(&mut self, batch: RecordBatch) -> RecordBatchEntry {
+        self.store.register(batch)
+    }
+
+    pub fn insert_batch_entry(&mut self, entry: RecordBatchEntry) {
+        self.store.insert(entry)
+    }
+
+    /// Returns the current threshold row (the largest / worst key in the
+    /// heap) if the heap already contains `k` items.
+    pub fn max(&self) -> Option<&NativeTopKRow> {
+        if self.inner.len() < self.k {
+            None
+        } else {
+            self.inner.peek()
+        }
+    }
+
+    /// Insert a new row. If the heap is full, evicts the worst entry.
+    pub fn add(&mut self, batch_entry: &mut RecordBatchEntry, key: u128, index: usize) {
+        let batch_id = batch_entry.id;
+        batch_entry.uses += 1;
+
+        debug_assert!(self.inner.len() <= self.k);
+
+        if self.inner.len() == self.k {
+            let mut prev = self.inner.peek_mut().unwrap();
+            if prev.batch_id == batch_entry.id {
+                batch_entry.uses -= 1;
+            } else {
+                self.store.unuse(prev.batch_id);
+            }
+            prev.key = key;
+            prev.batch_id = batch_id;
+            prev.index = index;
+            // PeekMut's drop will sift the replaced entry down
+        } else {
+            self.inner.push(NativeTopKRow::new(key, batch_id, index));
+        }
+    }
+
+    /// Drain the heap into a single sorted [`RecordBatch`].
+    pub fn emit(&mut self) -> Result<Option<RecordBatch>> {
+        Ok(self.emit_with_state()?.0)
+    }
+
+    pub fn emit_with_state(
+        &mut self,
+    ) -> Result<(Option<RecordBatch>, Vec<NativeTopKRow>)> {
+        let rows = std::mem::take(&mut self.inner).into_sorted_vec();
+
+        if self.store.is_empty() {
+            return Ok((None, rows));
+        }
+
+        let mut record_batches = Vec::new();
+        let mut id_to_pos = HashMap::new();
+        for (pos, (batch_id, batch)) in self.store.batches.iter().enumerate() {
+            record_batches.push(&batch.batch);
+            id_to_pos.insert(*batch_id, pos);
+        }
+
+        let indices: Vec<_> = rows
+            .iter()
+            .map(|r| (id_to_pos[&r.batch_id], r.index))
+            .collect();
+
+        let batch = interleave_record_batch(&record_batches, &indices)?;
+        Ok((Some(batch), rows))
+    }
+
+    /// Compact stored batches to reclaim memory from unused rows.
+    pub fn maybe_compact(&mut self) -> Result<()> {
+        let max_unused_rows = (20 * self.batch_size) + self.k;
+        let unused_rows = self.store.unused_rows();
+
+        if self.store.len() <= 2 || unused_rows < max_unused_rows {
+            return Ok(());
+        }
+
+        let num_rows = self.inner.len();
+        let (new_batch, mut rows) = self.emit_with_state()?;
+        let Some(new_batch) = new_batch else {
+            return Ok(());
+        };
+
+        self.store.clear();
+        let mut batch_entry = self.register_batch(new_batch);
+        batch_entry.uses = num_rows;
+
+        for (i, row) in rows.iter_mut().enumerate() {
+            row.batch_id = batch_entry.id;
+            row.index = i;
+        }
+        self.insert_batch_entry(batch_entry);
+        self.inner = BinaryHeap::from(rows);
+
+        Ok(())
+    }
+
+    pub fn size(&self) -> usize {
+        size_of::<Self>()
+            + (self.inner.capacity() * size_of::<NativeTopKRow>())
+            + self.store.size()
+    }
+}
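`add` relies on `BinaryHeap::peek_mut` to replace the worst entry in place: writing through the `PeekMut` guard and dropping it costs one sift-down, instead of a pop (sift-down) followed by a push (sift-up). A minimal sketch (not part of the patch) of that pattern:

```rust
use std::collections::BinaryHeap;

// Sketch: replace the max element of a BinaryHeap in place.
// Dropping the PeekMut guard restores the heap invariant with one sift-down.
fn replace_max(heap: &mut BinaryHeap<u128>, new_key: u128) {
    if let Some(mut top) = heap.peek_mut() {
        *top = new_key;
    } // guard dropped here -> sift-down
}

fn main() {
    let mut heap: BinaryHeap<u128> = [5, 9, 1].into_iter().collect();
    assert_eq!(heap.peek(), Some(&9));
    replace_max(&mut heap, 3); // 9 evicted, 3 inserted
    assert_eq!(heap.peek(), Some(&5));
}
```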
+
+// ---------------------------------------------------------------------------
+// Batch-level encoding + insertion
+// ---------------------------------------------------------------------------
+
+/// Encode values from `sort_key` and insert qualifying rows into `heap`.
+///
+/// Returns the number of heap replacements.
+pub fn find_new_native_topk_items(
+    heap: &mut NativeTopKHeap,
+    sort_key: &ArrayRef,
+    options: SortOptions,
+    items: impl Iterator<Item = usize>,
+    batch_entry: &mut RecordBatchEntry,
+) -> usize {
+    macro_rules! dispatch_signed {
+        ($arrow_ty:ty, $array:expr, $items:expr) => {{
+            let typed = $array.as_primitive::<$arrow_ty>();
+            find_items_inner(
+                heap,
+                typed,
+                |v| encode_signed(v as i64),
+                options,
+                $items,
+                batch_entry,
+            )
+        }};
+    }
+
+    macro_rules! dispatch_unsigned {
+        ($arrow_ty:ty, $array:expr, $items:expr) => {{
+            let typed = $array.as_primitive::<$arrow_ty>();
+            find_items_inner(heap, typed, |v| v as u64, options, $items, batch_entry)
+        }};
+    }
+
+    match sort_key.data_type() {
+        DataType::Int8 => dispatch_signed!(Int8Type, sort_key, items),
+        DataType::Int16 => dispatch_signed!(Int16Type, sort_key, items),
+        DataType::Int32 => dispatch_signed!(Int32Type, sort_key, items),
+        DataType::Int64 => dispatch_signed!(Int64Type, sort_key, items),
+        DataType::UInt8 => dispatch_unsigned!(UInt8Type, sort_key, items),
+        DataType::UInt16 => dispatch_unsigned!(UInt16Type, sort_key, items),
+        DataType::UInt32 => dispatch_unsigned!(UInt32Type, sort_key, items),
+        DataType::UInt64 => dispatch_unsigned!(UInt64Type, sort_key, items),
+        DataType::Float32 => {
+            let typed = sort_key.as_primitive::<Float32Type>();
+            find_items_inner(heap, typed, encode_f32, options, items, batch_entry)
+        }
+        DataType::Float64 => {
+            let typed = sort_key.as_primitive::<Float64Type>();
+            find_items_inner(heap, typed, encode_f64, options, items, batch_entry)
+        }
+        // Date/Time/Timestamp/Duration are stored as i32 or i64
+        DataType::Date32 => dispatch_signed!(Date32Type, sort_key, items),
+        DataType::Date64 => dispatch_signed!(Date64Type, sort_key, items),
+        DataType::Time32(TimeUnit::Second) => {
+            dispatch_signed!(Time32SecondType, sort_key, items)
+        }
+        DataType::Time32(TimeUnit::Millisecond) => {
+            dispatch_signed!(Time32MillisecondType, sort_key, items)
+        }
+        DataType::Time64(TimeUnit::Microsecond) => {
+            dispatch_signed!(Time64MicrosecondType, sort_key, items)
+        }
+        DataType::Time64(TimeUnit::Nanosecond) => {
+            dispatch_signed!(Time64NanosecondType, sort_key, items)
+        }
+        DataType::Timestamp(TimeUnit::Second, _) => {
+            dispatch_signed!(TimestampSecondType, sort_key, items)
+        }
+        DataType::Timestamp(TimeUnit::Millisecond, _) => {
+            dispatch_signed!(TimestampMillisecondType, sort_key, items)
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, _) => {
+            dispatch_signed!(TimestampMicrosecondType, sort_key, items)
+        }
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+            dispatch_signed!(TimestampNanosecondType, sort_key, items)
+        }
+        DataType::Duration(TimeUnit::Second) => {
+            dispatch_signed!(DurationSecondType, sort_key, items)
+        }
+        DataType::Duration(TimeUnit::Millisecond) => {
+            dispatch_signed!(DurationMillisecondType, sort_key, items)
+        }
+        DataType::Duration(TimeUnit::Microsecond) => {
+            dispatch_signed!(DurationMicrosecondType, sort_key, items)
+        }
+        DataType::Duration(TimeUnit::Nanosecond) => {
+            dispatch_signed!(DurationNanosecondType, sort_key, items)
+        }
+        other => unreachable!("unsupported native TopK type: {other}"),
+    }
+}
+
+/// Inner loop: iterate candidate indices, encode each value, and insert
+/// into the heap when it beats the current threshold.
+#[inline]
+fn find_items_inner<T: ArrowPrimitiveType>(
+    heap: &mut NativeTopKHeap,
+    array: &arrow::array::PrimitiveArray<T>,
+    encode: impl Fn(T::Native) -> u64,
+    options: SortOptions,
+    items: impl Iterator<Item = usize>,
+    batch_entry: &mut RecordBatchEntry,
+) -> usize {
+    let mut replacements = 0;
+    for index in items {
+        let key = if array.is_null(index) {
+            if options.nulls_first {
+                0u128
+            } else {
+                u128::MAX
+            }
+        } else {
+            encode_key(false, encode(array.value(index)), options)
+        };
+
+        match heap.max() {
+            Some(max_row) if key >= max_row.key => {}
+            _ => {
+                heap.add(batch_entry, key, index);
+                replacements += 1;
+            }
+        }
+    }
+    replacements
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_encode_signed_order() {
+        // i64 order should be preserved
+        assert!(encode_signed(i64::MIN) < encode_signed(-1));
+        assert!(encode_signed(-1) < encode_signed(0));
+        assert!(encode_signed(0) < encode_signed(1));
+        assert!(encode_signed(1) < encode_signed(i64::MAX));
+    }
+
+    #[test]
+    fn test_encode_f64_order() {
+        assert!(encode_f64(f64::NEG_INFINITY) < encode_f64(-1.0));
+        assert!(encode_f64(-1.0) < encode_f64(-0.0));
+        // -0.0 and +0.0 have different encodings (matching total_cmp)
+        assert!(encode_f64(-0.0) < encode_f64(0.0));
+        assert!(encode_f64(0.0) < encode_f64(1.0));
+        assert!(encode_f64(1.0) < encode_f64(f64::INFINITY));
+        assert!(encode_f64(f64::INFINITY) < encode_f64(f64::NAN));
+    }
+
+    #[test]
+    fn test_encode_key_ascending_nulls_last() {
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+        let null_key = encode_key(true, 0, opts);
+        let val_key = encode_key(false, 42, opts);
+        assert!(val_key < null_key, "non-null should sort before null");
+    }
+
+    #[test]
+    fn test_encode_key_ascending_nulls_first() {
+        let opts = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+        let null_key = encode_key(true, 0, opts);
+        let val_key = encode_key(false, 42, opts);
+        assert!(null_key < val_key, "null should sort before non-null");
+    }
+
+    #[test]
+    fn test_encode_key_descending_nulls_first() {
+        let opts = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+        let null_key = encode_key(true, 0, opts);
+        let small = encode_key(false, encode_signed(1), opts);
+        let large = encode_key(false, encode_signed(100), opts);
+        assert!(null_key < small, "null first in desc");
+        assert!(
+            large < small,
+            "larger value should have smaller key in desc"
+        );
+    }
+
+    #[test]
+    fn test_encode_key_descending_nulls_last() {
+        let opts = SortOptions {
+            descending: true,
+            nulls_first: false,
+        };
+        let null_key = encode_key(true, 0, opts);
+        let val_key = encode_key(false, encode_signed(1), opts);
+        assert!(val_key < null_key, "null last in desc");
+    }
+}
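For reference, a hedged end-to-end sketch of how the pieces compose, written as a test one might add next to this module (the module name and crate-internal visibility are assumptions, not part of the patch):

```rust
#[cfg(test)]
mod usage_sketch {
    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema, SortOptions};

    use super::*;

    // Sketch: feed one batch through the native heap and emit the top 3
    // smallest values of column "a" (ascending, nulls last).
    #[test]
    fn native_topk_end_to_end() -> datafusion_common::Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        let col: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(7), Some(2), None, Some(5), Some(1), Some(9),
        ]));
        let batch = RecordBatch::try_new(schema, vec![Arc::clone(&col)])?;

        let mut heap = NativeTopKHeap::new(3, 8192);
        let mut entry = heap.register_batch(batch.clone());
        let options = SortOptions { descending: false, nulls_first: false };
        let replaced =
            find_new_native_topk_items(&mut heap, &col, options, 0..col.len(), &mut entry);
        assert!(replaced > 0);
        heap.insert_batch_entry(entry);

        // emit() returns rows sorted by the encoded key: 1, 2, 5.
        let out = heap.emit()?.expect("non-empty heap");
        assert_eq!(out.num_rows(), 3);
        Ok(())
    }
}
```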