Skip to content

Commit 2f421a8

Browse files
fmeheustpsilberk
authored andcommitted
Any tag filter
1 parent 3cbee72 commit 2f421a8

2 files changed

Lines changed: 73 additions & 8 deletions

File tree

data/semantickernel-data-oracle/src/main/java/com/microsoft/semantickernel/data/jdbc/oracle/OracleVectorStoreQueryProvider.java

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import com.fasterxml.jackson.databind.JsonNode;
55
import com.fasterxml.jackson.databind.ObjectMapper;
66
import com.fasterxml.jackson.databind.node.ArrayNode;
7+
import com.microsoft.semantickernel.data.filter.AnyTagEqualToFilterClause;
8+
import com.microsoft.semantickernel.data.filter.EqualToFilterClause;
79
import com.microsoft.semantickernel.data.jdbc.*;
10+
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchFilter;
811
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
912
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResults;
1013
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordMapper;
@@ -34,6 +37,7 @@
3437
import java.time.OffsetDateTime;
3538
import java.util.ArrayList;
3639
import java.util.Collection;
40+
import java.util.Collections;
3741
import java.util.HashMap;
3842
import java.util.List;
3943
import java.util.Map;
@@ -387,13 +391,12 @@ public <Record> VectorSearchResults<Record> search(String collectionName, List<F
387391
PreparedStatement statement = connection.prepareStatement(selectQuery)) {
388392
// set parameters from filters
389393
int parameterIndex = 1;
390-
394+
System.out.println("Set vector parameter with index " + parameterIndex +" to: " + objectMapper.writeValueAsString(vector));
391395
statement.setString(parameterIndex++,
392396
objectMapper.writeValueAsString(vector));
393-
System.out.println("Set vector parameter to: " + objectMapper.writeValueAsString(vector));
394397
for (Object parameter : parameters) {
395-
statement.setObject(parameterIndex++, parameter);
396398
System.out.println("Set parameter " + parameterIndex + " to: " + parameter);
399+
statement.setObject(parameterIndex++, parameter);
397400
}
398401

399402
// Calls to defineColumnType reduce the number of network requests. When Oracle JDBC knows that it is
@@ -481,6 +484,44 @@ private String toOracleDistanceFunction(DistanceFunction distanceFunction) {
481484
}
482485
}
483486

487+
/**
488+
* Gets the filter parameters for the given vector search filter to associate with the filter
489+
* string generated by the getFilter method.
490+
*
491+
* @param filter The filter to get the filter parameters for.
492+
* @return The filter parameters.
493+
*/
494+
@Override
495+
public List<Object> getFilterParameters(VectorSearchFilter filter) {
496+
if (filter == null
497+
|| filter.getFilterClauses().isEmpty()) {
498+
return Collections.emptyList();
499+
}
500+
501+
return filter.getFilterClauses().stream().map(filterClause -> {
502+
if (filterClause instanceof EqualToFilterClause) {
503+
EqualToFilterClause equalToFilterClause = (EqualToFilterClause) filterClause;
504+
return equalToFilterClause.getValue();
505+
} else if (filterClause instanceof AnyTagEqualToFilterClause) {
506+
AnyTagEqualToFilterClause anyTagEqualToFilterClause = (AnyTagEqualToFilterClause) filterClause;
507+
return anyTagEqualToFilterClause.getValue();
508+
} else {
509+
throw new SKException("Unsupported filter clause type '"
510+
+ filterClause.getClass().getSimpleName() + "'.");
511+
}
512+
}).collect(Collectors.toList());
513+
}
514+
515+
@Override
516+
public String getAnyTagEqualToFilter(AnyTagEqualToFilterClause filterClause) {
517+
String fieldName = JDBCVectorStoreQueryProvider
518+
.validateSQLidentifier(filterClause.getFieldName());
519+
520+
return String.format("JSON_EXISTS(%s, '$[*]?(@ == $v_%s)' PASSING ? AS \"v_%s\")",
521+
fieldName, fieldName, fieldName);
522+
}
523+
524+
484525
public static Builder builder() {
485526
return new Builder();
486527
}

data/semantickernel-data-oracle/src/test/java/com/microsoft/semantickernel/data/jdbc/oracle/OracleVectorStoreRecordCollectionTest.java

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,19 @@ public void clearCollection() {
127127

128128
private static List<Hotel> getHotels() {
129129
return Arrays.asList(
130-
new Hotel("id_1", "Hotel 1", 1, 1.49d, null, "Hotel 1 description",
130+
new Hotel("id_1", "Hotel 1", 1, 1.49d, Arrays.asList("one", "two"), "Hotel 1 description",
131131
Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f), null, null, null,
132132
4.0),
133-
new Hotel("id_2", "Hotel 2", 2, 1.44d, null, "Hotel 2 description with free-text search",
133+
new Hotel("id_2", "Hotel 2", 2, 1.44d, Arrays.asList("three", "four"), "Hotel 2 description with free-text search",
134134
Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f), null, null, null,
135135
4.0),
136-
new Hotel("id_3", "Hotel 3", 3, 1.53d, null, "Hotel 3 description",
136+
new Hotel("id_3", "Hotel 3", 3, 1.53d, Arrays.asList("five", "six"), "Hotel 3 description",
137137
Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f), null, null, null,
138138
5.0),
139-
new Hotel("id_4", "Hotel 4", 4, 1.35d, null, "Hotel 4 description",
139+
new Hotel("id_4", "Hotel 4", 4, 1.35d, Arrays.asList("seven", "eight"), "Hotel 4 description",
140140
Arrays.asList(7.0f, 1.2f, -5.3f, 2.5f, 6.6f, -7.8f, 3.9f, -0.1f), null, null, null,
141141
4.0),
142-
new Hotel("id_5", "Hotel 5", 5, 1.89d, null,"Hotel 5 description",
142+
new Hotel("id_5", "Hotel 5", 5, 1.89d, Arrays.asList("nine", "ten"),"Hotel 5 description",
143143
Arrays.asList(-3.5f, 4.4f, -1.2f, 9.9f, 5.7f, -6.1f, 7.8f, -2.0f), null, null, null,
144144
4.0));
145145
}
@@ -300,6 +300,30 @@ public void searchWithFilter(DistanceFunction distanceFunction, double expectedD
300300
assertEquals(results.get(0).getScore(), expectedDistance, 0.0001d);
301301
}
302302

303+
304+
@Test
305+
public void searchWithTagFilter() {
306+
List<Hotel> hotels = getHotels();
307+
recordCollection.upsertBatchAsync(hotels, null).block();
308+
309+
VectorSearchOptions options = VectorSearchOptions.builder()
310+
// .withVectorFieldName("")
311+
.withTop(3)
312+
.withVectorSearchFilter(
313+
VectorSearchFilter.builder()
314+
.anyTagEqualTo("tags", "three")
315+
.build())
316+
.build();
317+
318+
// Embeddings similar to the third hotel, but as the filter is set to 4.0, the third hotel should not be returned
319+
List<VectorSearchResult<Hotel>> results = recordCollection
320+
.searchAsync(SEARCH_EMBEDDINGS, options).block().getResults();
321+
assertNotNull(results);
322+
assertEquals(1, results.size());
323+
// The first hotel should be the most similar
324+
assertEquals(hotels.get(1).getId(), results.get(0).getRecord().getId());
325+
}
326+
303327
private static Stream<Arguments> distanceFunctionAndDistance() {
304328
return Stream.of(
305329
Arguments.of (DistanceFunction.COSINE_DISTANCE, 0.8548d),

0 commit comments

Comments
 (0)