Skip to content

Commit 0851e0d

Browse files
author
David Grieve
authored
Merge pull request #215 from milderhc/jdbc-vector-search
Add vector search for JDBC in general
2 parents ecae78f + 470f6c1 commit 0851e0d

26 files changed

Lines changed: 691 additions & 314 deletions

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,37 @@
1111
public class Hotel {
1212
@VectorStoreRecordKeyAttribute
1313
private final String id;
14+
1415
@VectorStoreRecordDataAttribute
1516
private final String name;
17+
1618
@VectorStoreRecordDataAttribute
1719
private final int code;
20+
1821
@JsonProperty("summary")
1922
@VectorStoreRecordDataAttribute()
2023
private final String description;
21-
@JsonProperty("summaryEmbedding")
22-
@VectorStoreRecordVectorAttribute(dimensions = 8)
23-
private final List<Float> descriptionEmbedding;
24+
25+
@JsonProperty("summaryEmbedding1")
26+
@VectorStoreRecordVectorAttribute(dimensions = 8, distanceFunction = "euclidean")
27+
private final List<Float> euclidean;
28+
29+
@JsonProperty("summaryEmbedding2")
30+
@VectorStoreRecordVectorAttribute(dimensions = 8, distanceFunction = "cosineDistance")
31+
private final List<Float> cosineDistance;
32+
33+
@JsonProperty("summaryEmbedding3")
34+
@VectorStoreRecordVectorAttribute(dimensions = 8, distanceFunction = "dotProduct")
35+
private final List<Float> dotProduct;
36+
2437
@JsonProperty("indexedSummaryEmbedding")
25-
@VectorStoreRecordVectorAttribute(dimensions = 8, indexKind = "hnsw", distanceFunction = "cosine")
26-
private final List<Float> indexedDescriptionEmbedding;
38+
@VectorStoreRecordVectorAttribute(dimensions = 8, indexKind = "hnsw", distanceFunction = "euclidean")
39+
private final List<Float> indexedEuclidean;
2740
@VectorStoreRecordDataAttribute
2841
private double rating;
2942

3043
public Hotel() {
31-
this(null, null, 0, null, null, null, 0.0);
44+
this(null, null, 0, null, null, null, null, null, 0.0);
3245
}
3346

3447
@JsonCreator
@@ -37,15 +50,19 @@ public Hotel(
3750
@JsonProperty("name") String name,
3851
@JsonProperty("code") int code,
3952
@JsonProperty("summary") String description,
40-
@JsonProperty("summaryEmbedding") List<Float> descriptionEmbedding,
41-
@JsonProperty("indexedSummaryEmbedding") List<Float> indexedDescriptionEmbedding,
53+
@JsonProperty("summaryEmbedding1") List<Float> euclidean,
54+
@JsonProperty("summaryEmbedding2") List<Float> cosineDistance,
55+
@JsonProperty("summaryEmbedding3") List<Float> dotProduct,
56+
@JsonProperty("indexedSummaryEmbedding") List<Float> indexedEuclidean,
4257
@JsonProperty("rating") double rating) {
4358
this.id = id;
4459
this.name = name;
4560
this.code = code;
4661
this.description = description;
47-
this.descriptionEmbedding = descriptionEmbedding;
48-
this.indexedDescriptionEmbedding = indexedDescriptionEmbedding;
62+
this.euclidean = euclidean;
63+
this.cosineDistance = euclidean;
64+
this.dotProduct = euclidean;
65+
this.indexedEuclidean = euclidean;
4966
this.rating = rating;
5067
}
5168

@@ -65,11 +82,12 @@ public String getDescription() {
6582
return description;
6683
}
6784

68-
public List<Float> getDescriptionEmbedding() {
69-
return descriptionEmbedding;
85+
public List<Float> getEuclidean() {
86+
return euclidean;
7087
}
71-
public List<Float> getIndexedDescriptionEmbedding() {
72-
return indexedDescriptionEmbedding;
88+
89+
public List<Float> getIndexedEuclidean() {
90+
return indexedEuclidean;
7391
}
7492

7593
public double getRating() {

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
66
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
77
import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
8+
import com.microsoft.semantickernel.connectors.data.jdbc.filter.SQLEqualToFilterClause;
9+
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchFilter;
810
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
911
import com.microsoft.semantickernel.data.vectorstorage.options.GetRecordOptions;
1012
import com.microsoft.semantickernel.data.vectorstorage.options.VectorSearchOptions;
1113
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
1214
import com.mysql.cj.jdbc.MysqlDataSource;
1315
import org.junit.jupiter.api.Test;
1416
import org.junit.jupiter.params.ParameterizedTest;
17+
import org.junit.jupiter.params.provider.Arguments;
1518
import org.junit.jupiter.params.provider.EnumSource;
19+
import org.junit.jupiter.params.provider.MethodSource;
1620
import org.postgresql.ds.PGSimpleDataSource;
1721
import org.testcontainers.containers.MySQLContainer;
1822
import org.testcontainers.containers.PostgreSQLContainer;
@@ -25,6 +29,7 @@
2529
import java.util.ArrayList;
2630
import java.util.Arrays;
2731
import java.util.List;
32+
import java.util.stream.Stream;
2833

2934
import static org.junit.jupiter.api.Assertions.assertEquals;
3035
import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -99,19 +104,23 @@ private List<Hotel> getHotels() {
99104
ArrayList<Hotel> embeddings = new ArrayList<>();
100105

101106
return List.of(
102-
new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f),
103-
Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f),4.0),
104-
new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f),
105-
Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f),3.0),
106-
new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f),
107-
Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f),5.0),
108-
new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(7.0f, 1.2f, -5.3f, 2.5f, 6.6f, -7.8f, 3.9f, -0.1f),
109-
Arrays.asList(7.0f, 1.2f, -5.3f, 2.5f, 6.6f, -7.8f, 3.9f, -0.1f),4.0),
110-
new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(-3.5f, 4.4f, -1.2f, 9.9f, 5.7f, -6.1f, 7.8f, -2.0f),
111-
Arrays.asList(-3.5f, 4.4f, -1.2f, 9.9f, 5.7f, -6.1f, 7.8f, -2.0f),5.0)
107+
new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f),null, null, null, 4.0),
108+
new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f),null, null, null, 4.0),
109+
new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f),null, null, null, 5.0),
110+
new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(7.0f, 1.2f, -5.3f, 2.5f, 6.6f, -7.8f, 3.9f, -0.1f),null, null, null, 4.0),
111+
new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(-3.5f, 4.4f, -1.2f, 9.9f, 5.7f, -6.1f, 7.8f, -2.0f),null, null, null, 4.0)
112112
);
113113
}
114114

115+
/**
116+
* Search embeddings similar to the third hotel's embeddings.
117+
* In order of similarity:
118+
* 1. Hotel 3
119+
* 2. Hotel 1
120+
* 3. Hotel 4
121+
*/
122+
private static final List<Float> SEARCH_EMBEDDINGS = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f);
123+
115124
@ParameterizedTest
116125
@EnumSource(QueryProvider.class)
117126
public void upsertAndGetRecordAsync(QueryProvider provider) {
@@ -263,7 +272,7 @@ public void getWithNoVectors(QueryProvider provider) {
263272
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block();
264273
assertNotNull(retrievedHotel);
265274
assertEquals(hotel.getId(), retrievedHotel.getId());
266-
assertNull(retrievedHotel.getDescriptionEmbedding());
275+
assertNull(retrievedHotel.getEuclidean());
267276
}
268277

269278
options = GetRecordOptions.builder()
@@ -274,7 +283,7 @@ public void getWithNoVectors(QueryProvider provider) {
274283
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block();
275284
assertNotNull(retrievedHotel);
276285
assertEquals(hotel.getId(), retrievedHotel.getId());
277-
assertNotNull(retrievedHotel.getDescriptionEmbedding());
286+
assertNotNull(retrievedHotel.getEuclidean());
278287
}
279288
}
280289

@@ -301,7 +310,7 @@ public void getBatchWithNoVectors(QueryProvider provider) {
301310
assertEquals(hotels.size(), retrievedHotels.size());
302311

303312
for (Hotel hotel : retrievedHotels) {
304-
assertNull(hotel.getDescriptionEmbedding());
313+
assertNull(hotel.getEuclidean());
305314
}
306315

307316
options = GetRecordOptions.builder()
@@ -313,83 +322,127 @@ public void getBatchWithNoVectors(QueryProvider provider) {
313322
assertEquals(hotels.size(), retrievedHotels.size());
314323

315324
for (Hotel hotel : retrievedHotels) {
316-
assertNotNull(hotel.getDescriptionEmbedding());
325+
assertNotNull(hotel.getEuclidean());
317326
}
318327
}
319328

329+
private static Stream<Arguments> provideSearchParameters() {
330+
return Stream.of(
331+
Arguments.of(QueryProvider.MySQL, "euclidean"),
332+
Arguments.of(QueryProvider.MySQL, "cosineDistance"),
333+
Arguments.of(QueryProvider.MySQL, "dotProduct"),
334+
Arguments.of(QueryProvider.PostgreSQL, "euclidean"),
335+
Arguments.of(QueryProvider.PostgreSQL, "cosineDistance"),
336+
Arguments.of(QueryProvider.PostgreSQL, "dotProduct")
337+
);
338+
}
320339

321-
@Test
322-
public void postgresExactSearch() {
323-
String collectionName = "search";
324-
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(QueryProvider.PostgreSQL, collectionName);
340+
@ParameterizedTest
341+
@MethodSource("provideSearchParameters")
342+
public void exactSearch(QueryProvider provider, String embeddingName) {
343+
String collectionName = "search" + embeddingName;
344+
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(provider, collectionName);
325345

326346
List<Hotel> hotels = getHotels();
327347
recordCollection.upsertBatchAsync(hotels, null).block();
328348

349+
VectorSearchOptions options = VectorSearchOptions.builder()
350+
.withVectorFieldName(embeddingName)
351+
.withLimit(3)
352+
.build();
353+
329354
// Embeddings similar to the third hotel
330-
List<Float> embeddings = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f);
331-
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(embeddings, null).block();
355+
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
332356
assertNotNull(results);
333357
assertEquals(3, results.size());
334358
// The third hotel should be the most similar
335359
assertEquals(hotels.get(2).getId(), results.get(0).getRecord().getId());
336360

361+
options = VectorSearchOptions.builder()
362+
.withVectorFieldName(embeddingName)
363+
.withOffset(1)
364+
.withLimit(-100)
365+
.build();
366+
337367
// Skip the first result
338-
results = recordCollection.searchAsync(embeddings, VectorSearchOptions.builder().withOffset(1).withLimit(-100).build()).block();
368+
results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
339369
assertNotNull(results);
340370
assertEquals(1, results.size());
341371
// With the top result skipped, the first hotel should be the most similar
342372
assertEquals(hotels.get(0).getId(), results.get(0).getRecord().getId());
343373
}
344374

345-
@Test
346-
public void postgresApproximateSearch() {
375+
@ParameterizedTest
376+
@EnumSource(QueryProvider.class)
377+
public void approximateSearch(QueryProvider provider) {
347378
String collectionName = "searchWithIndex";
348-
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(QueryProvider.PostgreSQL, collectionName);
379+
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(provider, collectionName);
349380

350381
List<Hotel> hotels = getHotels();
351382
recordCollection.upsertBatchAsync(hotels, null).block();
352383

353384
VectorSearchOptions options = VectorSearchOptions.builder()
354-
.withVectorFieldName("indexedDescriptionEmbedding")
385+
.withVectorFieldName("indexedEuclidean")
355386
.withLimit(5)
356387
.build();
357388

358389
// Embeddings similar to the third hotel
359-
List<Float> embeddings = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f);
360-
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(embeddings, options).block();
390+
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
361391
assertNotNull(results);
362392
assertEquals(5, results.size());
363393
// The third hotel should be the most similar
364394
assertEquals(hotels.get(2).getId(), results.get(0).getRecord().getId());
365395
}
366396

397+
@ParameterizedTest
398+
@MethodSource("provideSearchParameters")
399+
public void searchWithFilter(QueryProvider provider, String embeddingName) {
400+
String collectionName = "searchWithFilter";
401+
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(provider, collectionName);
402+
403+
List<Hotel> hotels = getHotels();
404+
recordCollection.upsertBatchAsync(hotels, null).block();
405+
406+
VectorSearchOptions options = VectorSearchOptions.builder()
407+
.withVectorFieldName(embeddingName)
408+
.withLimit(3)
409+
.withVectorSearchFilter(
410+
VectorSearchFilter.builder().withEqualToFilterClause(new SQLEqualToFilterClause("rating", 4.0)).build())
411+
.build();
412+
413+
// Embeddings similar to the third hotel, but as the filter only keeps hotels with a rating of 4.0, the third hotel (rated 5.0) should not be returned
414+
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
415+
assertNotNull(results);
416+
assertEquals(3, results.size());
417+
// Among the hotels that pass the rating filter, the first hotel should be the most similar
418+
assertEquals(hotels.get(0).getId(), results.get(0).getRecord().getId());
419+
}
420+
421+
// MySQL will always return the vectors as they're needed to compute the distances, so this test only targets PostgreSQL
367422
@Test
368-
public void searchIncludeAndNotIncludeVectors() {
423+
public void postgresSearchIncludeAndNotIncludeVectors() {
369424
String collectionName = "searchIncludeAndNotIncludeVectors";
370425
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(QueryProvider.PostgreSQL, collectionName);
371426

372427
List<Hotel> hotels = getHotels();
373428
recordCollection.upsertBatchAsync(hotels, null).block();
374429

375-
// Embeddings similar to the third hotel
376-
List<Float> embeddings = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f);
377-
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(embeddings, null).block();
430+
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, null).block();
378431
assertNotNull(results);
379432
assertEquals(3, results.size());
380433
// The third hotel should be the most similar
381434
assertEquals(hotels.get(2).getId(), results.get(0).getRecord().getId());
382-
assertNull(results.get(0).getRecord().getDescriptionEmbedding());
435+
assertNull(results.get(0).getRecord().getEuclidean());
383436

384437
VectorSearchOptions options = VectorSearchOptions.builder()
385438
.withIncludeVectors(true)
386439
.build();
387440

388-
results = recordCollection.searchAsync(embeddings, options).block();
441+
results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
389442
assertNotNull(results);
390443
assertEquals(3, results.size());
391444
// The third hotel should be the most similar
392445
assertEquals(hotels.get(2).getId(), results.get(0).getRecord().getId());
393-
assertNotNull(results.get(0).getRecord().getDescriptionEmbedding());
446+
assertNotNull(results.get(0).getRecord().getEuclidean());
394447
}
395448
}

0 commit comments

Comments
 (0)