Skip to content

Commit c4cab96

Browse files
author
Milder Hernandez Cagua
committed
Add JDBC and Postgres filtering
1 parent b0dbb37 commit c4cab96

15 files changed

Lines changed: 280 additions & 106 deletions

File tree

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
66
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
77
import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
8+
import com.microsoft.semantickernel.connectors.data.jdbc.filter.SQLEqualToFilterClause;
9+
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchFilter;
810
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
911
import com.microsoft.semantickernel.data.vectorstorage.options.GetRecordOptions;
1012
import com.microsoft.semantickernel.data.vectorstorage.options.VectorSearchOptions;
@@ -103,13 +105,22 @@ private List<Hotel> getHotels() {
103105

104106
return List.of(
105107
new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f),null, null, null, 4.0),
106-
new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f),null, null, null, 3.0),
108+
new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f),null, null, null, 4.0),
107109
new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f),null, null, null, 5.0),
108-
new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(7.0f, 1.2f, -5.3f, 2.5f, 6.6f, -7.8f, 3.9f, -0.1f),null, null, null, 2.0),
109-
new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(-3.5f, 4.4f, -1.2f, 9.9f, 5.7f, -6.1f, 7.8f, -2.0f),null, null, null, 1.0)
110+
new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(7.0f, 1.2f, -5.3f, 2.5f, 6.6f, -7.8f, 3.9f, -0.1f),null, null, null, 4.0),
111+
new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(-3.5f, 4.4f, -1.2f, 9.9f, 5.7f, -6.1f, 7.8f, -2.0f),null, null, null, 4.0)
110112
);
111113
}
112114

115+
/**
116+
* Search embeddings similar to the third hotel embeddings.
117+
* In order of similarity:
118+
* 1. Hotel 3
119+
* 2. Hotel 1
120+
* 3. Hotel 4
121+
*/
122+
private static final List<Float> SEARCH_EMBEDDINGS = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f);
123+
113124
@ParameterizedTest
114125
@EnumSource(QueryProvider.class)
115126
public void upsertAndGetRecordAsync(QueryProvider provider) {
@@ -315,7 +326,7 @@ public void getBatchWithNoVectors(QueryProvider provider) {
315326
}
316327
}
317328

318-
private static Stream<Arguments> provideParameters() {
329+
private static Stream<Arguments> provideSearchParameters() {
319330
return Stream.of(
320331
Arguments.of(QueryProvider.MySQL, "euclidean"),
321332
Arguments.of(QueryProvider.MySQL, "cosineDistance"),
@@ -327,9 +338,9 @@ private static Stream<Arguments> provideParameters() {
327338
}
328339

329340
@ParameterizedTest
330-
@MethodSource("provideParameters")
341+
@MethodSource("provideSearchParameters")
331342
public void exactSearch(QueryProvider provider, String embeddingName) {
332-
String collectionName = "search";
343+
String collectionName = "search" + embeddingName;
333344
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(provider, collectionName);
334345

335346
List<Hotel> hotels = getHotels();
@@ -341,8 +352,7 @@ public void exactSearch(QueryProvider provider, String embeddingName) {
341352
.build();
342353

343354
// Embeddings similar to the third hotel
344-
List<Float> embeddings = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f);
345-
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(embeddings, options).block();
355+
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
346356
assertNotNull(results);
347357
assertEquals(3, results.size());
348358
// The third hotel should be the most similar
@@ -355,7 +365,7 @@ public void exactSearch(QueryProvider provider, String embeddingName) {
355365
.build();
356366

357367
// Skip the first result
358-
results = recordCollection.searchAsync(embeddings, options).block();
368+
results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
359369
assertNotNull(results);
360370
assertEquals(1, results.size());
361371
// The first hotel should be the most similar
@@ -377,14 +387,37 @@ public void approximateSearch(QueryProvider provider) {
377387
.build();
378388

379389
// Embeddings similar to the third hotel
380-
List<Float> embeddings = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f);
381-
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(embeddings, options).block();
390+
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
382391
assertNotNull(results);
383392
assertEquals(5, results.size());
384393
// The third hotel should be the most similar
385394
assertEquals(hotels.get(2).getId(), results.get(0).getRecord().getId());
386395
}
387396

397+
@ParameterizedTest
398+
@MethodSource("provideSearchParameters")
399+
public void searchWithFilter(QueryProvider provider, String embeddingName) {
400+
String collectionName = "searchWithFilter";
401+
JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(provider, collectionName);
402+
403+
List<Hotel> hotels = getHotels();
404+
recordCollection.upsertBatchAsync(hotels, null).block();
405+
406+
VectorSearchOptions options = VectorSearchOptions.builder()
407+
.withVectorFieldName(embeddingName)
408+
.withLimit(3)
409+
.withBasicVectorSearchFilter(
410+
VectorSearchFilter.builder().withEqualToFilterClause(new SQLEqualToFilterClause("rating", 4.0)).build())
411+
.build();
412+
413+
// Embeddings similar to the third hotel, but as the filter is set to 4.0, the third hotel should not be returned
414+
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
415+
assertNotNull(results);
416+
assertEquals(3, results.size());
417+
// The first hotel should be the most similar
418+
assertEquals(hotels.get(0).getId(), results.get(0).getRecord().getId());
419+
}
420+
388421
// MySQL will always return the vectors as they're needed to compute the distances
389422
@Test
390423
public void postgresSearchIncludeAndNotIncludeVectors() {
@@ -394,9 +427,7 @@ public void postgresSearchIncludeAndNotIncludeVectors() {
394427
List<Hotel> hotels = getHotels();
395428
recordCollection.upsertBatchAsync(hotels, null).block();
396429

397-
// Embeddings similar to the third hotel
398-
List<Float> embeddings = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f);
399-
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(embeddings, null).block();
430+
List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, null).block();
400431
assertNotNull(results);
401432
assertEquals(3, results.size());
402433
// The third hotel should be the most similar
@@ -407,7 +438,7 @@ public void postgresSearchIncludeAndNotIncludeVectors() {
407438
.withIncludeVectors(true)
408439
.build();
409440

410-
results = recordCollection.searchAsync(embeddings, options).block();
441+
results = recordCollection.searchAsync(SEARCH_EMBEDDINGS, options).block();
411442
assertNotNull(results);
412443
assertEquals(3, results.size());
413444
// The third hotel should be the most similar

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreCollectionSearchMapping.java

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,40 @@
11
// Copyright (c) Microsoft. All rights reserved.
22
package com.microsoft.semantickernel.connectors.data.azureaisearch;
33

4-
import com.microsoft.semantickernel.data.filtering.BasicVectorSearchFilter;
4+
import com.microsoft.semantickernel.connectors.data.azureaisearch.filter.AzureAISearchEqualToFilterClause;
5+
import com.microsoft.semantickernel.connectors.data.azureaisearch.filter.AzureAISearchAnyTagEqualToFilterClause;
6+
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchFilter;
7+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
58
import com.microsoft.semantickernel.exceptions.SKException;
69

7-
import java.util.Map;
810
import java.util.stream.Collectors;
911

1012
class AzureAISearchVectorStoreCollectionSearchMapping {
11-
public static String buildFilterString(BasicVectorSearchFilter basicVectorSearchFilter,
12-
Map<String, String> storageNames) {
13-
if (basicVectorSearchFilter == null
14-
|| basicVectorSearchFilter.getFilterClauses().isEmpty()) {
13+
public static String buildFilterString(VectorSearchFilter vectorSearchFilter,
14+
VectorStoreRecordDefinition recordDefinition) {
15+
if (vectorSearchFilter == null
16+
|| vectorSearchFilter.getFilterClauses().isEmpty()) {
1517
return "";
1618
}
1719

1820
return String.join(" and ",
19-
basicVectorSearchFilter.getFilterClauses().stream().map(filterClause -> {
20-
if (filterClause instanceof AzureAISearchEqualityFilterClause) {
21-
AzureAISearchEqualityFilterClause azureFilterClause = (AzureAISearchEqualityFilterClause) filterClause;
21+
vectorSearchFilter.getFilterClauses().stream().map(filterClause -> {
22+
if (filterClause instanceof AzureAISearchEqualToFilterClause) {
23+
AzureAISearchEqualToFilterClause azureFilterClause = (AzureAISearchEqualToFilterClause) filterClause;
2224
// Create new instance with the storage name of the field
23-
return new AzureAISearchEqualityFilterClause(
24-
storageNames.get(azureFilterClause.getFieldName()),
25-
azureFilterClause.getValue()).getFilter();
26-
} else if (filterClause instanceof AzureAISearchTagListContainsFilterClause) {
27-
AzureAISearchTagListContainsFilterClause azureFilterClause = (AzureAISearchTagListContainsFilterClause) filterClause;
25+
return new AzureAISearchEqualToFilterClause(
26+
recordDefinition.getField(azureFilterClause.getFieldName())
27+
.getEffectiveStorageName(),
28+
azureFilterClause.getValue())
29+
.getFilter();
30+
} else if (filterClause instanceof AzureAISearchAnyTagEqualToFilterClause) {
31+
AzureAISearchAnyTagEqualToFilterClause azureFilterClause = (AzureAISearchAnyTagEqualToFilterClause) filterClause;
2832
// Create new instance with the storage name of the field
29-
return new AzureAISearchTagListContainsFilterClause(
30-
storageNames.get(azureFilterClause.getFieldName()),
31-
azureFilterClause.getValue()).getFilter();
33+
return new AzureAISearchAnyTagEqualToFilterClause(
34+
recordDefinition.getField(azureFilterClause.getFieldName())
35+
.getEffectiveStorageName(),
36+
azureFilterClause.getValue())
37+
.getFilter();
3238
} else {
3339
throw new SKException("Unsupported filter clause type '"
3440
+ filterClause.getClass().getSimpleName() + "'.");

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ public class AzureAISearchVectorStoreRecordCollection<Record> implements
8484

8585
// List of non-vector fields. Used to fetch only non-vector fields when vectors are not requested
8686
private final List<String> nonVectorFields = new ArrayList<>();
87-
private final Map<String, String> storageNames;
8887
private final String firstVectorFieldName;
8988

9089
@SuppressFBWarnings("EI_EXPOSE_REP2")
@@ -119,7 +118,6 @@ public AzureAISearchVectorStoreRecordCollection(
119118
.map(VectorStoreRecordDataField::getEffectiveStorageName)
120119
.collect(Collectors.toList()));
121120

122-
storageNames = recordDefinition.getFieldStorageNames();
123121
firstVectorFieldName = recordDefinition.getVectorFields().isEmpty() ? null
124122
: recordDefinition.getVectorFields().get(0).getName();
125123
}
@@ -323,25 +321,23 @@ public Mono<List<VectorSearchResult<Record>>> searchAsync(VectorSearchQuery quer
323321

324322
if (query instanceof VectorizedSearchQuery) {
325323
vectorQueries.add(new VectorizedQuery(((VectorizedSearchQuery) query).getVector())
326-
.setFields(
327-
storageNames
328-
.get(options.getVectorFieldName() != null ? options.getVectorFieldName()
329-
: firstVectorFieldName))
324+
.setFields(recordDefinition.getField(options.getVectorFieldName() != null
325+
? options.getVectorFieldName()
326+
: firstVectorFieldName).getEffectiveStorageName())
330327
.setKNearestNeighborsCount(options.getLimit()));
331328
} else if (query instanceof VectorizableTextSearchQuery) {
332329
vectorQueries
333330
.add(new VectorizableTextQuery(((VectorizableTextSearchQuery) query).getQueryText())
334-
.setFields(
335-
storageNames
336-
.get(options.getVectorFieldName() != null ? options.getVectorFieldName()
337-
: firstVectorFieldName))
331+
.setFields(recordDefinition.getField(options.getVectorFieldName() != null
332+
? options.getVectorFieldName()
333+
: firstVectorFieldName).getEffectiveStorageName())
338334
.setKNearestNeighborsCount(options.getLimit()));
339335
} else {
340336
throw new SKException("Unsupported query type: " + query.getQueryType());
341337
}
342338

343339
String filter = AzureAISearchVectorStoreCollectionSearchMapping
344-
.buildFilterString(options.getBasicVectorSearchFilter(), storageNames);
340+
.buildFilterString(options.getBasicVectorSearchFilter(), recordDefinition);
345341

346342
SearchOptions searchOptions = new SearchOptions()
347343
.setFilter(filter)

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchTagListContainsFilterClause.java renamed to semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/filter/AzureAISearchAnyTagEqualToFilterClause.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
// Copyright (c) Microsoft. All rights reserved.
2-
package com.microsoft.semantickernel.connectors.data.azureaisearch;
2+
package com.microsoft.semantickernel.connectors.data.azureaisearch.filter;
33

4-
import com.microsoft.semantickernel.data.filtering.TagListContainsFilterClause;
4+
import com.microsoft.semantickernel.data.filter.AnyTagEqualToFilterClause;
55

6-
public class AzureAISearchTagListContainsFilterClause extends TagListContainsFilterClause {
6+
public class AzureAISearchAnyTagEqualToFilterClause extends AnyTagEqualToFilterClause {
77

88
/**
99
* Initializes a new instance of the AzureAISearchTagListContainsFilterClause class.
1010
*
1111
* @param fieldName The field name to filter on.
1212
* @param value The value.
1313
*/
14-
public AzureAISearchTagListContainsFilterClause(String fieldName, Object value) {
14+
public AzureAISearchAnyTagEqualToFilterClause(String fieldName, Object value) {
1515
super(fieldName, value);
1616
}
1717

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchEqualityFilterClause.java renamed to semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/filter/AzureAISearchEqualToFilterClause.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
// Copyright (c) Microsoft. All rights reserved.
2-
package com.microsoft.semantickernel.connectors.data.azureaisearch;
2+
package com.microsoft.semantickernel.connectors.data.azureaisearch.filter;
33

4-
import com.microsoft.semantickernel.data.filtering.EqualityFilterClause;
4+
import com.microsoft.semantickernel.data.filter.EqualToFilterClause;
55
import com.microsoft.semantickernel.exceptions.SKException;
66

77
import java.time.OffsetDateTime;
88
import java.time.format.DateTimeFormatter;
99

10-
public class AzureAISearchEqualityFilterClause extends EqualityFilterClause {
10+
public class AzureAISearchEqualToFilterClause extends EqualToFilterClause {
1111

1212
/**
1313
* Initializes a new instance of the AzureAISearchEqualityFilterClause class.
1414
*
1515
* @param fieldName The field name to filter on.
1616
* @param value The value.
1717
*/
18-
public AzureAISearchEqualityFilterClause(String fieldName, Object value) {
18+
public AzureAISearchEqualToFilterClause(String fieldName, Object value) {
1919
super(fieldName, value);
2020
}
2121

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import com.fasterxml.jackson.core.JsonProcessingException;
55
import com.fasterxml.jackson.databind.ObjectMapper;
66
import com.fasterxml.jackson.databind.node.ArrayNode;
7+
import com.microsoft.semantickernel.data.filter.FilterClause;
78
import com.microsoft.semantickernel.data.vectorsearch.VectorOperations;
89
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
910
import com.microsoft.semantickernel.data.vectorsearch.queries.VectorSearchQuery;
@@ -427,22 +428,31 @@ public void deleteRecords(String collectionName, List<String> keys,
427428
}
428429
}
429430

430-
protected <Record> List<Record> getAllRecords(String collectionName,
431+
protected <Record> List<Record> getRecordsWithFilter(String collectionName,
431432
VectorStoreRecordDefinition recordDefinition,
432-
VectorStoreRecordMapper<Record, ResultSet> mapper, GetRecordOptions options) {
433+
VectorStoreRecordMapper<Record, ResultSet> mapper, GetRecordOptions options, String filter,
434+
List<Object> parameters) {
433435
List<VectorStoreRecordField> fields;
434436
if (options.isIncludeVectors()) {
435437
fields = recordDefinition.getAllFields();
436438
} else {
437439
fields = recordDefinition.getNonVectorFields();
438440
}
439441

440-
String selectQuery = formatQuery("SELECT %s FROM %s",
442+
String filterClause = filter == null || filter.isEmpty() ? "" : "WHERE " + filter;
443+
String selectQuery = formatQuery("SELECT %s FROM %s %s",
441444
getQueryColumnsFromFields(fields),
442-
getCollectionTableName(collectionName));
445+
getCollectionTableName(collectionName),
446+
filterClause);
443447

444448
try (Connection connection = dataSource.getConnection();
445449
PreparedStatement statement = connection.prepareStatement(selectQuery)) {
450+
if (parameters != null) {
451+
for (int i = 0; i < parameters.size(); ++i) {
452+
statement.setObject(i + 1, parameters.get(i));
453+
}
454+
}
455+
446456
List<Record> records = new ArrayList<>();
447457
ResultSet resultSet = statement.executeQuery();
448458
while (resultSet.next()) {
@@ -491,8 +501,13 @@ public <Record> List<VectorSearchResult<Record>> search(String collectionName,
491501
: (VectorStoreRecordVectorField) recordDefinition
492502
.getField(options.getVectorFieldName());
493503

494-
List<Record> records = getAllRecords(collectionName, recordDefinition, mapper,
495-
new GetRecordOptions(true));
504+
String filter = SQLVectorStoreRecordCollectionSearchMapping
505+
.buildFilter(options.getBasicVectorSearchFilter(), recordDefinition);
506+
List<Object> parameters = SQLVectorStoreRecordCollectionSearchMapping
507+
.getFilterParameters(options.getBasicVectorSearchFilter());
508+
509+
List<Record> records = getRecordsWithFilter(collectionName, recordDefinition, mapper,
510+
new GetRecordOptions(true), filter, parameters);
496511
List<VectorSearchResult<Record>> results = new ArrayList<>();
497512

498513
DistanceFunction distanceFunction = vectorField.getDistanceFunction() == null

0 commit comments

Comments
 (0)