55import com .microsoft .semantickernel .connectors .data .jdbc .JDBCVectorStoreRecordCollectionOptions ;
66import com .microsoft .semantickernel .connectors .data .mysql .MySQLVectorStoreQueryProvider ;
77import com .microsoft .semantickernel .connectors .data .postgres .PostgreSQLVectorStoreQueryProvider ;
8+ import com .microsoft .semantickernel .connectors .data .jdbc .filter .SQLEqualToFilterClause ;
9+ import com .microsoft .semantickernel .data .vectorsearch .VectorSearchFilter ;
810import com .microsoft .semantickernel .data .vectorsearch .VectorSearchResult ;
911import com .microsoft .semantickernel .data .vectorstorage .options .GetRecordOptions ;
1012import com .microsoft .semantickernel .data .vectorstorage .options .VectorSearchOptions ;
1113import com .microsoft .semantickernel .tests .connectors .memory .Hotel ;
1214import com .mysql .cj .jdbc .MysqlDataSource ;
1315import org .junit .jupiter .api .Test ;
1416import org .junit .jupiter .params .ParameterizedTest ;
17+ import org .junit .jupiter .params .provider .Arguments ;
1518import org .junit .jupiter .params .provider .EnumSource ;
19+ import org .junit .jupiter .params .provider .MethodSource ;
1620import org .postgresql .ds .PGSimpleDataSource ;
1721import org .testcontainers .containers .MySQLContainer ;
1822import org .testcontainers .containers .PostgreSQLContainer ;
2529import java .util .ArrayList ;
2630import java .util .Arrays ;
2731import java .util .List ;
32+ import java .util .stream .Stream ;
2833
2934import static org .junit .jupiter .api .Assertions .assertEquals ;
3035import static org .junit .jupiter .api .Assertions .assertNotNull ;
@@ -99,19 +104,23 @@ private List<Hotel> getHotels() {
99104 ArrayList <Hotel > embeddings = new ArrayList <>();
100105
101106 return List .of (
102- new Hotel ("id_1" , "Hotel 1" , 1 , "Hotel 1 description" , Arrays .asList (0.5f , 3.2f , 7.1f , -4.0f , 2.8f , 10.0f , -1.3f , 5.5f ),
103- Arrays .asList (0.5f , 3.2f , 7.1f , -4.0f , 2.8f , 10.0f , -1.3f , 5.5f ),4.0 ),
104- new Hotel ("id_2" , "Hotel 2" , 2 , "Hotel 2 description" , Arrays .asList (-2.0f , 8.1f , 0.9f , 5.4f , -3.3f , 2.2f , 9.9f , -4.5f ),
105- Arrays .asList (-2.0f , 8.1f , 0.9f , 5.4f , -3.3f , 2.2f , 9.9f , -4.5f ),3.0 ),
106- new Hotel ("id_3" , "Hotel 3" , 3 , "Hotel 3 description" , Arrays .asList (4.5f , -6.2f , 3.1f , 7.7f , -0.8f , 1.1f , -2.2f , 8.3f ),
107- Arrays .asList (4.5f , -6.2f , 3.1f , 7.7f , -0.8f , 1.1f , -2.2f , 8.3f ),5.0 ),
108- new Hotel ("id_4" , "Hotel 4" , 4 , "Hotel 4 description" , Arrays .asList (7.0f , 1.2f , -5.3f , 2.5f , 6.6f , -7.8f , 3.9f , -0.1f ),
109- Arrays .asList (7.0f , 1.2f , -5.3f , 2.5f , 6.6f , -7.8f , 3.9f , -0.1f ),4.0 ),
110- new Hotel ("id_5" , "Hotel 5" , 5 , "Hotel 5 description" , Arrays .asList (-3.5f , 4.4f , -1.2f , 9.9f , 5.7f , -6.1f , 7.8f , -2.0f ),
111- Arrays .asList (-3.5f , 4.4f , -1.2f , 9.9f , 5.7f , -6.1f , 7.8f , -2.0f ),5.0 )
107+ new Hotel ("id_1" , "Hotel 1" , 1 , "Hotel 1 description" , Arrays .asList (0.5f , 3.2f , 7.1f , -4.0f , 2.8f , 10.0f , -1.3f , 5.5f ),null , null , null , 4.0 ),
108+ new Hotel ("id_2" , "Hotel 2" , 2 , "Hotel 2 description" , Arrays .asList (-2.0f , 8.1f , 0.9f , 5.4f , -3.3f , 2.2f , 9.9f , -4.5f ),null , null , null , 4.0 ),
109+ new Hotel ("id_3" , "Hotel 3" , 3 , "Hotel 3 description" , Arrays .asList (4.5f , -6.2f , 3.1f , 7.7f , -0.8f , 1.1f , -2.2f , 8.3f ),null , null , null , 5.0 ),
110+ new Hotel ("id_4" , "Hotel 4" , 4 , "Hotel 4 description" , Arrays .asList (7.0f , 1.2f , -5.3f , 2.5f , 6.6f , -7.8f , 3.9f , -0.1f ),null , null , null , 4.0 ),
111+ new Hotel ("id_5" , "Hotel 5" , 5 , "Hotel 5 description" , Arrays .asList (-3.5f , 4.4f , -1.2f , 9.9f , 5.7f , -6.1f , 7.8f , -2.0f ),null , null , null , 4.0 )
112112 );
113113 }
114114
115+ /**
116+ * Search embeddings similar to the third hotel embeddings.
117+ * In order of similarity:
118+ * 1. Hotel 3
119+ * 2. Hotel 1
120+ * 3. Hotel 4
121+ */
122+ private static final List <Float > SEARCH_EMBEDDINGS = Arrays .asList (4.5f , -6.2f , 3.1f , 7.7f , -0.8f , 1.1f , -2.2f , 8.2f );
123+
115124 @ ParameterizedTest
116125 @ EnumSource (QueryProvider .class )
117126 public void upsertAndGetRecordAsync (QueryProvider provider ) {
@@ -263,7 +272,7 @@ public void getWithNoVectors(QueryProvider provider) {
263272 Hotel retrievedHotel = recordCollection .getAsync (hotel .getId (), options ).block ();
264273 assertNotNull (retrievedHotel );
265274 assertEquals (hotel .getId (), retrievedHotel .getId ());
266- assertNull (retrievedHotel .getDescriptionEmbedding ());
275+ assertNull (retrievedHotel .getEuclidean ());
267276 }
268277
269278 options = GetRecordOptions .builder ()
@@ -274,7 +283,7 @@ public void getWithNoVectors(QueryProvider provider) {
274283 Hotel retrievedHotel = recordCollection .getAsync (hotel .getId (), options ).block ();
275284 assertNotNull (retrievedHotel );
276285 assertEquals (hotel .getId (), retrievedHotel .getId ());
277- assertNotNull (retrievedHotel .getDescriptionEmbedding ());
286+ assertNotNull (retrievedHotel .getEuclidean ());
278287 }
279288 }
280289
@@ -301,7 +310,7 @@ public void getBatchWithNoVectors(QueryProvider provider) {
301310 assertEquals (hotels .size (), retrievedHotels .size ());
302311
303312 for (Hotel hotel : retrievedHotels ) {
304- assertNull (hotel .getDescriptionEmbedding ());
313+ assertNull (hotel .getEuclidean ());
305314 }
306315
307316 options = GetRecordOptions .builder ()
@@ -313,83 +322,127 @@ public void getBatchWithNoVectors(QueryProvider provider) {
313322 assertEquals (hotels .size (), retrievedHotels .size ());
314323
315324 for (Hotel hotel : retrievedHotels ) {
316- assertNotNull (hotel .getDescriptionEmbedding ());
325+ assertNotNull (hotel .getEuclidean ());
317326 }
318327 }
319328
329+ private static Stream <Arguments > provideSearchParameters () {
330+ return Stream .of (
331+ Arguments .of (QueryProvider .MySQL , "euclidean" ),
332+ Arguments .of (QueryProvider .MySQL , "cosineDistance" ),
333+ Arguments .of (QueryProvider .MySQL , "dotProduct" ),
334+ Arguments .of (QueryProvider .PostgreSQL , "euclidean" ),
335+ Arguments .of (QueryProvider .PostgreSQL , "cosineDistance" ),
336+ Arguments .of (QueryProvider .PostgreSQL , "dotProduct" )
337+ );
338+ }
320339
321- @ Test
322- public void postgresExactSearch () {
323- String collectionName = "search" ;
324- JDBCVectorStoreRecordCollection <Hotel > recordCollection = buildRecordCollection (QueryProvider .PostgreSQL , collectionName );
340+ @ ParameterizedTest
341+ @ MethodSource ("provideSearchParameters" )
342+ public void exactSearch (QueryProvider provider , String embeddingName ) {
343+ String collectionName = "search" + embeddingName ;
344+ JDBCVectorStoreRecordCollection <Hotel > recordCollection = buildRecordCollection (provider , collectionName );
325345
326346 List <Hotel > hotels = getHotels ();
327347 recordCollection .upsertBatchAsync (hotels , null ).block ();
328348
349+ VectorSearchOptions options = VectorSearchOptions .builder ()
350+ .withVectorFieldName (embeddingName )
351+ .withLimit (3 )
352+ .build ();
353+
329354 // Embeddings similar to the third hotel
330- List <Float > embeddings = Arrays .asList (4.5f , -6.2f , 3.1f , 7.7f , -0.8f , 1.1f , -2.2f , 8.2f );
331- List <VectorSearchResult <Hotel >> results = recordCollection .searchAsync (embeddings , null ).block ();
355+ List <VectorSearchResult <Hotel >> results = recordCollection .searchAsync (SEARCH_EMBEDDINGS , options ).block ();
332356 assertNotNull (results );
333357 assertEquals (3 , results .size ());
334358 // The third hotel should be the most similar
335359 assertEquals (hotels .get (2 ).getId (), results .get (0 ).getRecord ().getId ());
336360
361+ options = VectorSearchOptions .builder ()
362+ .withVectorFieldName (embeddingName )
363+ .withOffset (1 )
364+ .withLimit (-100 )
365+ .build ();
366+
337367 // Skip the first result
338- results = recordCollection .searchAsync (embeddings , VectorSearchOptions . builder (). withOffset ( 1 ). withLimit (- 100 ). build () ).block ();
368+ results = recordCollection .searchAsync (SEARCH_EMBEDDINGS , options ).block ();
339369 assertNotNull (results );
340370 assertEquals (1 , results .size ());
341371 // The first hotel should be the most similar
342372 assertEquals (hotels .get (0 ).getId (), results .get (0 ).getRecord ().getId ());
343373 }
344374
345- @ Test
346- public void postgresApproximateSearch () {
375+ @ ParameterizedTest
376+ @ EnumSource (QueryProvider .class )
377+ public void approximateSearch (QueryProvider provider ) {
347378 String collectionName = "searchWithIndex" ;
348- JDBCVectorStoreRecordCollection <Hotel > recordCollection = buildRecordCollection (QueryProvider . PostgreSQL , collectionName );
379+ JDBCVectorStoreRecordCollection <Hotel > recordCollection = buildRecordCollection (provider , collectionName );
349380
350381 List <Hotel > hotels = getHotels ();
351382 recordCollection .upsertBatchAsync (hotels , null ).block ();
352383
353384 VectorSearchOptions options = VectorSearchOptions .builder ()
354- .withVectorFieldName ("indexedDescriptionEmbedding " )
385+ .withVectorFieldName ("indexedEuclidean " )
355386 .withLimit (5 )
356387 .build ();
357388
358389 // Embeddings similar to the third hotel
359- List <Float > embeddings = Arrays .asList (4.5f , -6.2f , 3.1f , 7.7f , -0.8f , 1.1f , -2.2f , 8.2f );
360- List <VectorSearchResult <Hotel >> results = recordCollection .searchAsync (embeddings , options ).block ();
390+ List <VectorSearchResult <Hotel >> results = recordCollection .searchAsync (SEARCH_EMBEDDINGS , options ).block ();
361391 assertNotNull (results );
362392 assertEquals (5 , results .size ());
363393 // The third hotel should be the most similar
364394 assertEquals (hotels .get (2 ).getId (), results .get (0 ).getRecord ().getId ());
365395 }
366396
397+ @ ParameterizedTest
398+ @ MethodSource ("provideSearchParameters" )
399+ public void searchWithFilter (QueryProvider provider , String embeddingName ) {
400+ String collectionName = "searchWithFilter" ;
401+ JDBCVectorStoreRecordCollection <Hotel > recordCollection = buildRecordCollection (provider , collectionName );
402+
403+ List <Hotel > hotels = getHotels ();
404+ recordCollection .upsertBatchAsync (hotels , null ).block ();
405+
406+ VectorSearchOptions options = VectorSearchOptions .builder ()
407+ .withVectorFieldName (embeddingName )
408+ .withLimit (3 )
409+ .withVectorSearchFilter (
410+ VectorSearchFilter .builder ().withEqualToFilterClause (new SQLEqualToFilterClause ("rating" , 4.0 )).build ())
411+ .build ();
412+
413+ // Embeddings similar to the third hotel, but as the filter is set to 4.0, the third hotel should not be returned
414+ List <VectorSearchResult <Hotel >> results = recordCollection .searchAsync (SEARCH_EMBEDDINGS , options ).block ();
415+ assertNotNull (results );
416+ assertEquals (3 , results .size ());
417+ // The first hotel should be the most similar
418+ assertEquals (hotels .get (0 ).getId (), results .get (0 ).getRecord ().getId ());
419+ }
420+
421+ // MySQL will always return the vectors as they're needed to compute the distances
367422 @ Test
368- public void searchIncludeAndNotIncludeVectors () {
423+ public void postgresSearchIncludeAndNotIncludeVectors () {
369424 String collectionName = "searchIncludeAndNotIncludeVectors" ;
370425 JDBCVectorStoreRecordCollection <Hotel > recordCollection = buildRecordCollection (QueryProvider .PostgreSQL , collectionName );
371426
372427 List <Hotel > hotels = getHotels ();
373428 recordCollection .upsertBatchAsync (hotels , null ).block ();
374429
375- // Embeddings similar to the third hotel
376- List <Float > embeddings = Arrays .asList (4.5f , -6.2f , 3.1f , 7.7f , -0.8f , 1.1f , -2.2f , 8.2f );
377- List <VectorSearchResult <Hotel >> results = recordCollection .searchAsync (embeddings , null ).block ();
430+ List <VectorSearchResult <Hotel >> results = recordCollection .searchAsync (SEARCH_EMBEDDINGS , null ).block ();
378431 assertNotNull (results );
379432 assertEquals (3 , results .size ());
380433 // The third hotel should be the most similar
381434 assertEquals (hotels .get (2 ).getId (), results .get (0 ).getRecord ().getId ());
382- assertNull (results .get (0 ).getRecord ().getDescriptionEmbedding ());
435+ assertNull (results .get (0 ).getRecord ().getEuclidean ());
383436
384437 VectorSearchOptions options = VectorSearchOptions .builder ()
385438 .withIncludeVectors (true )
386439 .build ();
387440
388- results = recordCollection .searchAsync (embeddings , options ).block ();
441+ results = recordCollection .searchAsync (SEARCH_EMBEDDINGS , options ).block ();
389442 assertNotNull (results );
390443 assertEquals (3 , results .size ());
391444 // The third hotel should be the most similar
392445 assertEquals (hotels .get (2 ).getId (), results .get (0 ).getRecord ().getId ());
393- assertNotNull (results .get (0 ).getRecord ().getDescriptionEmbedding ());
446+ assertNotNull (results .get (0 ).getRecord ().getEuclidean ());
394447 }
395448}
0 commit comments