|
1 | 1 | package com.microsoft.semantickernel.tests.connectors.memory.jdbc; |
2 | 2 |
|
3 | | -import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider; |
| 3 | +import com.microsoft.semantickernel.connectors.data.jdbc.SQLVectorStoreQueryProvider; |
4 | 4 | import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollection; |
5 | 5 | import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions; |
6 | 6 | import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider; |
7 | 7 | import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider; |
| 8 | +import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult; |
8 | 9 | import com.microsoft.semantickernel.data.vectorstorage.options.GetRecordOptions; |
| 10 | +import com.microsoft.semantickernel.data.vectorstorage.options.VectorSearchOptions; |
9 | 11 | import com.microsoft.semantickernel.tests.connectors.memory.Hotel; |
10 | 12 | import com.mysql.cj.jdbc.MysqlDataSource; |
| 13 | +import org.junit.jupiter.api.Test; |
11 | 14 | import org.junit.jupiter.params.ParameterizedTest; |
12 | 15 | import org.junit.jupiter.params.provider.EnumSource; |
13 | 16 | import org.postgresql.ds.PGSimpleDataSource; |
@@ -44,7 +47,7 @@ public enum QueryProvider { |
44 | 47 | } |
45 | 48 |
|
46 | 49 | private JDBCVectorStoreRecordCollection<Hotel> buildRecordCollection(QueryProvider provider, @Nonnull String collectionName) { |
47 | | - JDBCVectorStoreQueryProvider queryProvider; |
| 50 | + SQLVectorStoreQueryProvider queryProvider; |
48 | 51 | DataSource dataSource; |
49 | 52 |
|
50 | 53 | switch (provider) { |
@@ -93,12 +96,19 @@ public void buildRecordCollection(QueryProvider provider) { |
93 | 96 | } |
94 | 97 |
|
95 | 98 | private List<Hotel> getHotels() { |
| 99 | + // Five fixture hotels, each with a distinct 8-component embedding; the same vector literal is passed twice per hotel (presumably one per vector field of Hotel - confirm against the Hotel constructor) |
| 100 | + |
96 | 101 | return List.of( |
97 | | - new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), |
98 | | - new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f), 3.0), |
99 | | - new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0), |
100 | | - new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), |
101 | | - new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0) |
| 102 | + new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f), |
| 103 | + Arrays.asList(0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f),4.0), |
| 104 | + new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f), |
| 105 | + Arrays.asList(-2.0f, 8.1f, 0.9f, 5.4f, -3.3f, 2.2f, 9.9f, -4.5f),3.0), |
| 106 | + new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f), |
| 107 | + Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.3f),5.0), |
| 108 | + new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(7.0f, 1.2f, -5.3f, 2.5f, 6.6f, -7.8f, 3.9f, -0.1f), |
| 109 | + Arrays.asList(7.0f, 1.2f, -5.3f, 2.5f, 6.6f, -7.8f, 3.9f, -0.1f),4.0), |
| 110 | + new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(-3.5f, 4.4f, -1.2f, 9.9f, 5.7f, -6.1f, 7.8f, -2.0f), |
| 111 | + Arrays.asList(-3.5f, 4.4f, -1.2f, 9.9f, 5.7f, -6.1f, 7.8f, -2.0f),5.0) |
102 | 112 | ); |
103 | 113 | } |
104 | 114 |
|
@@ -306,4 +316,80 @@ public void getBatchWithNoVectors(QueryProvider provider) { |
306 | 316 | assertNotNull(hotel.getDescriptionEmbedding()); |
307 | 317 | } |
308 | 318 | } |
| 319 | + |
| 320 | + |
| 321 | + @Test |
| 322 | + public void postgresExactSearch() { |
| 323 | + String collectionName = "search"; |
| 324 | + JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(QueryProvider.PostgreSQL, collectionName); |
| 325 | + |
| 326 | + List<Hotel> hotels = getHotels(); |
| 327 | + recordCollection.upsertBatchAsync(hotels, null).block(); |
| 328 | + |
| 329 | + // Query vector: the third hotel's embedding with only the last component changed (8.3f -> 8.2f); first searched with null options |
| 330 | + List<Float> embeddings = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f); |
| 331 | + List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(embeddings, null).block(); |
| 332 | + assertNotNull(results); |
| 333 | + assertEquals(3, results.size()); |
| 334 | + // The third hotel should rank first, since the query is almost exactly its own embedding |
| 335 | + assertEquals(hotels.get(2).getId(), results.get(0).getRecord().getId()); |
| 336 | + |
| 337 | + // Skip the top result (offset 1); the limit of -100 is presumably clamped to 1, which the size assertion below relies on - TODO confirm |
| 338 | + results = recordCollection.searchAsync(embeddings, VectorSearchOptions.builder().withOffset(1).withLimit(-100).build()).block(); |
| 339 | + assertNotNull(results); |
| 340 | + assertEquals(1, results.size()); |
| 341 | + // With the third hotel skipped, the first hotel is expected to be the next most similar |
| 342 | + assertEquals(hotels.get(0).getId(), results.get(0).getRecord().getId()); |
| 343 | + } |
| 344 | + |
| 345 | + @Test |
| 346 | + public void postgresApproximateSearch() { |
| 347 | + String collectionName = "searchWithIndex"; |
| 348 | + JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(QueryProvider.PostgreSQL, collectionName); |
| 349 | + |
| 350 | + List<Hotel> hotels = getHotels(); |
| 351 | + recordCollection.upsertBatchAsync(hotels, null).block(); |
| 352 | + |
| 353 | + VectorSearchOptions options = VectorSearchOptions.builder() |
| 354 | + .withVectorFieldName("indexedDescriptionEmbedding") |
| 355 | + .withLimit(5) |
| 356 | + .build(); |
| 357 | + |
| 358 | + // Query vector: the third hotel's embedding with only the last component changed (8.3f -> 8.2f); searched against the "indexedDescriptionEmbedding" field with a limit of 5 |
| 359 | + List<Float> embeddings = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f); |
| 360 | + List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(embeddings, options).block(); |
| 361 | + assertNotNull(results); |
| 362 | + assertEquals(5, results.size()); |
| 363 | + // All five hotels are expected back, with the third hotel as the most similar |
| 364 | + assertEquals(hotels.get(2).getId(), results.get(0).getRecord().getId()); |
| 365 | + } |
| 366 | + |
| 367 | + @Test |
| 368 | + public void searchIncludeAndNotIncludeVectors() { |
| 369 | + String collectionName = "searchIncludeAndNotIncludeVectors"; |
| 370 | + JDBCVectorStoreRecordCollection<Hotel> recordCollection = buildRecordCollection(QueryProvider.PostgreSQL, collectionName); |
| 371 | + |
| 372 | + List<Hotel> hotels = getHotels(); |
| 373 | + recordCollection.upsertBatchAsync(hotels, null).block(); |
| 374 | + |
| 375 | + // Query vector: the third hotel's embedding with only the last component changed (8.3f -> 8.2f); first searched with null options |
| 376 | + List<Float> embeddings = Arrays.asList(4.5f, -6.2f, 3.1f, 7.7f, -0.8f, 1.1f, -2.2f, 8.2f); |
| 377 | + List<VectorSearchResult<Hotel>> results = recordCollection.searchAsync(embeddings, null).block(); |
| 378 | + assertNotNull(results); |
| 379 | + assertEquals(3, results.size()); |
| 380 | + // The third hotel should be the most similar, and with null options its description embedding is expected to be absent |
| 381 | + assertEquals(hotels.get(2).getId(), results.get(0).getRecord().getId()); |
| 382 | + assertNull(results.get(0).getRecord().getDescriptionEmbedding()); |
| 383 | + |
| 384 | + VectorSearchOptions options = VectorSearchOptions.builder() |
| 385 | + .withIncludeVectors(true) |
| 386 | + .build(); |
| 387 | + |
| 388 | + results = recordCollection.searchAsync(embeddings, options).block(); |
| 389 | + assertNotNull(results); |
| 390 | + assertEquals(3, results.size()); |
| 391 | + // The third hotel should still be the most similar, and with includeVectors(true) its description embedding is expected to be present |
| 392 | + assertEquals(hotels.get(2).getId(), results.get(0).getRecord().getId()); |
| 393 | + assertNotNull(results.get(0).getRecord().getDescriptionEmbedding()); |
| 394 | + } |
309 | 395 | } |
0 commit comments