Skip to content

Commit 643db7a

Browse files
crm26claude
andauthored
feat: add cosine_distance scalar function (#21542)
## Summary - Adds `cosine_distance(array1, array2)` / `list_cosine_distance` — computes cosine distance (1 - cosine similarity) between two numeric arrays - Introduces shared `vector_math.rs` primitives (`dot_product_f64`, `magnitude_f64`, `convert_to_f64_array`) for reuse by follow-on vector functions - Returns NULL for zero-magnitude vectors; errors on mismatched lengths - Supports List, LargeList, and FixedSizeList with any numeric element type Part of #21536 — first in a series of split PRs (replacing #21371). ## Test plan - [x] Unit tests: identical, orthogonal, opposite, 45-degree, zero-magnitude, mismatched-length, NULL, multi-row - [x] sqllogictest: `cosine_distance.slt` covering all edge cases including empty arrays, LargeList, integer coercion, alias, return type - [x] Full slt suite (426/426 pass) - [x] `cargo clippy`, `cargo fmt`, `taplo`, `prettier`, `cargo machete` — all clean 🤖 Generated with [Claude Code](https://claude.com/claude-code) --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 82abcbd commit 643db7a

4 files changed

Lines changed: 414 additions & 0 deletions

File tree

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`ScalarUDFImpl`] definitions for cosine_distance function.
19+
20+
use crate::utils::make_scalar_function;
21+
use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
22+
use arrow::datatypes::{
23+
DataType,
24+
DataType::{FixedSizeList, LargeList, List, Null},
25+
Field,
26+
};
27+
use datafusion_common::cast::{as_float64_array, as_generic_list_array};
28+
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
29+
use datafusion_common::{
30+
Result, exec_err, internal_err, plan_err, utils::take_function_args,
31+
};
32+
use datafusion_expr::{
33+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34+
Volatility,
35+
};
36+
use datafusion_macros::user_doc;
37+
use std::sync::Arc;
38+
39+
make_udf_expr_and_func!(
40+
CosineDistance,
41+
cosine_distance,
42+
array1 array2,
43+
"returns the cosine distance between two numeric arrays.",
44+
cosine_distance_udf
45+
);
46+
47+
#[user_doc(
48+
doc_section(label = "Array Functions"),
49+
description = "Returns the cosine distance between two input arrays of equal length. The cosine distance is defined as 1 - cosine_similarity, i.e. `1 - dot(a,b) / (||a|| * ||b||)`. Returns NULL if either array is NULL or contains only zeros.",
50+
syntax_example = "cosine_distance(array1, array2)",
51+
sql_example = r#"```sql
52+
> select cosine_distance([1.0, 0.0], [0.0, 1.0]);
53+
+-----------------------------------------------+
54+
| cosine_distance(List([1.0,0.0]),List([0.0,1.0])) |
55+
+-----------------------------------------------+
56+
| 1.0 |
57+
+-----------------------------------------------+
58+
```"#,
59+
argument(
60+
name = "array1",
61+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
62+
),
63+
argument(
64+
name = "array2",
65+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
66+
)
67+
)]
68+
#[derive(Debug, PartialEq, Eq, Hash)]
69+
pub struct CosineDistance {
70+
signature: Signature,
71+
}
72+
73+
impl Default for CosineDistance {
74+
fn default() -> Self {
75+
Self::new()
76+
}
77+
}
78+
79+
impl CosineDistance {
80+
pub fn new() -> Self {
81+
Self {
82+
signature: Signature::user_defined(Volatility::Immutable),
83+
}
84+
}
85+
}
86+
87+
impl ScalarUDFImpl for CosineDistance {
88+
fn name(&self) -> &str {
89+
"cosine_distance"
90+
}
91+
92+
fn signature(&self) -> &Signature {
93+
&self.signature
94+
}
95+
96+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
97+
Ok(DataType::Float64)
98+
}
99+
100+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
101+
let [_, _] = take_function_args(self.name(), arg_types)?;
102+
let coercion = Some(&ListCoercion::FixedSizedListToList);
103+
104+
for arg_type in arg_types {
105+
if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
106+
return plan_err!("{} does not support type {arg_type}", self.name());
107+
}
108+
}
109+
110+
// If any input is `LargeList`, both sides must be widened to `LargeList`
111+
// so the runtime dispatch in `cosine_distance_inner` sees a homogeneous
112+
// pair. Follows the pattern in `ArrayConcat::coerce_types`.
113+
let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
114+
115+
let coerced = arg_types
116+
.iter()
117+
.map(|arg_type| {
118+
if matches!(arg_type, Null) {
119+
let field = Arc::new(Field::new_list_field(DataType::Float64, true));
120+
return if any_large_list {
121+
LargeList(field)
122+
} else {
123+
List(field)
124+
};
125+
}
126+
let coerced = coerced_type_with_base_type_only(
127+
arg_type,
128+
&DataType::Float64,
129+
coercion,
130+
);
131+
match coerced {
132+
List(field) if any_large_list => LargeList(field),
133+
other => other,
134+
}
135+
})
136+
.collect();
137+
138+
Ok(coerced)
139+
}
140+
141+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
142+
make_scalar_function(cosine_distance_inner)(&args.args)
143+
}
144+
145+
fn documentation(&self) -> Option<&Documentation> {
146+
self.doc()
147+
}
148+
}
149+
150+
fn cosine_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
151+
let [array1, array2] = take_function_args("cosine_distance", args)?;
152+
match (array1.data_type(), array2.data_type()) {
153+
(List(_), List(_)) => general_cosine_distance::<i32>(args),
154+
(LargeList(_), LargeList(_)) => general_cosine_distance::<i64>(args),
155+
(arg_type1, arg_type2) => internal_err!(
156+
"cosine_distance received unexpected types after coercion: {arg_type1} and {arg_type2}"
157+
),
158+
}
159+
}
160+
161+
fn general_cosine_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
162+
let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
163+
let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
164+
165+
let values1 = as_float64_array(list_array1.values())?;
166+
let values2 = as_float64_array(list_array2.values())?;
167+
let offsets1 = list_array1.value_offsets();
168+
let offsets2 = list_array2.value_offsets();
169+
170+
let mut builder = Float64Array::builder(list_array1.len());
171+
for row in 0..list_array1.len() {
172+
if list_array1.is_null(row) || list_array2.is_null(row) {
173+
builder.append_null();
174+
continue;
175+
}
176+
177+
let start1 = offsets1[row].as_usize();
178+
let end1 = offsets1[row + 1].as_usize();
179+
let start2 = offsets2[row].as_usize();
180+
let end2 = offsets2[row + 1].as_usize();
181+
let len1 = end1 - start1;
182+
let len2 = end2 - start2;
183+
184+
if len1 != len2 {
185+
return exec_err!(
186+
"cosine_distance requires both list inputs to have the same length, got {len1} and {len2}"
187+
);
188+
}
189+
190+
let slice1 = values1.slice(start1, len1);
191+
let slice2 = values2.slice(start2, len2);
192+
if slice1.null_count() != 0 || slice2.null_count() != 0 {
193+
builder.append_null();
194+
continue;
195+
}
196+
197+
let vals1 = slice1.values();
198+
let vals2 = slice2.values();
199+
200+
let mut dot = 0.0;
201+
let mut sq1 = 0.0;
202+
let mut sq2 = 0.0;
203+
for i in 0..len1 {
204+
let a = vals1[i];
205+
let b = vals2[i];
206+
dot += a * b;
207+
sq1 += a * a;
208+
sq2 += b * b;
209+
}
210+
211+
if sq1 == 0.0 || sq2 == 0.0 {
212+
builder.append_null();
213+
} else {
214+
builder.append_value(1.0 - dot / (sq1.sqrt() * sq2.sqrt()));
215+
}
216+
}
217+
218+
Ok(Arc::new(builder.finish()) as ArrayRef)
219+
}

datafusion/functions-nested/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub mod array_has;
4242
pub mod arrays_zip;
4343
pub mod cardinality;
4444
pub mod concat;
45+
pub mod cosine_distance;
4546
pub mod dimension;
4647
pub mod distance;
4748
pub mod empty;
@@ -87,6 +88,7 @@ pub mod expr_fn {
8788
pub use super::concat::array_append;
8889
pub use super::concat::array_concat;
8990
pub use super::concat::array_prepend;
91+
pub use super::cosine_distance::cosine_distance;
9092
pub use super::dimension::array_dims;
9193
pub use super::dimension::array_ndims;
9294
pub use super::distance::array_distance;
@@ -153,6 +155,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
153155
array_has::array_has_any_udf(),
154156
empty::array_empty_udf(),
155157
length::array_length_udf(),
158+
cosine_distance::cosine_distance_udf(),
156159
distance::array_distance_udf(),
157160
flatten::flatten_udf(),
158161
min_max::array_max_udf(),

0 commit comments

Comments
 (0)