|
1 | 1 | package com.microsoft.semantickernel.data.jdbc.oracle; |
2 | 2 |
|
| 3 | +import com.fasterxml.jackson.annotation.JsonCreator; |
| 4 | +import com.fasterxml.jackson.annotation.JsonProperty; |
3 | 5 | import com.microsoft.semantickernel.data.VolatileVectorStoreRecordCollection; |
4 | 6 | import com.microsoft.semantickernel.data.VolatileVectorStoreRecordCollectionOptions; |
5 | 7 | import com.microsoft.semantickernel.data.jdbc.JDBCVectorStore; |
|
16 | 18 | import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordKeyField; |
17 | 19 | import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordVectorField; |
18 | 20 | import com.microsoft.semantickernel.data.vectorstorage.options.VectorSearchOptions; |
| 21 | + |
19 | 22 | import oracle.jdbc.OracleConnection; |
20 | 23 | import oracle.jdbc.datasource.impl.OracleDataSource; |
21 | 24 | import org.junit.jupiter.api.BeforeAll; |
|
33 | 36 | import java.sql.Statement; |
34 | 37 | import java.time.Duration; |
35 | 38 | import java.util.Arrays; |
| 39 | +import java.util.HashMap; |
36 | 40 | import java.util.List; |
| 41 | +import java.util.UUID; |
37 | 42 | import java.util.logging.Logger; |
38 | 43 | import java.util.stream.Collectors; |
39 | 44 | import java.util.stream.Stream; |
@@ -349,6 +354,69 @@ public void searchWithTagFilter() { |
349 | 354 | assertEquals(hotels.get(1).getId(), results.get(0).getRecord().getId()); |
350 | 355 | } |
351 | 356 |
|
| 357 | + @ParameterizedTest |
| 358 | + @MethodSource("supportedKeyTypes") |
| 359 | + <T> void testKeyTypes(String suffix, Class<?> keyType, Object keyValue) { |
| 360 | + VectorStoreRecordKeyField keyField = VectorStoreRecordKeyField.builder() |
| 361 | + .withName("id") |
| 362 | + .withStorageName("id") |
| 363 | + .withFieldType(keyType) |
| 364 | + .build(); |
| 365 | + |
| 366 | + VectorStoreRecordDataField dummyField = VectorStoreRecordDataField.builder() |
| 367 | + .withName("dummy") |
| 368 | + .withStorageName("dummy") |
| 369 | + .withFieldType(String.class) |
| 370 | + .build(); |
| 371 | + |
| 372 | + VectorStoreRecordVectorField dummyVector = VectorStoreRecordVectorField.builder() |
| 373 | + .withName("vec") |
| 374 | + .withStorageName("vec") |
| 375 | + .withFieldType(List.class) |
| 376 | + .withDimensions(2) |
| 377 | + .withDistanceFunction(DistanceFunction.EUCLIDEAN_DISTANCE) |
| 378 | + .withIndexKind(IndexKind.UNDEFINED) |
| 379 | + .build(); |
| 380 | + |
| 381 | + VectorStoreRecordDefinition definition = VectorStoreRecordDefinition.fromFields( |
| 382 | + Arrays.asList(keyField, dummyField, dummyVector) |
| 383 | + ); |
| 384 | + |
| 385 | + OracleVectorStoreQueryProvider queryProvider = OracleVectorStoreQueryProvider.builder() |
| 386 | + .withDataSource(DATA_SOURCE) |
| 387 | + .build(); |
| 388 | + |
| 389 | + JDBCVectorStore vectorStore = JDBCVectorStore.builder() |
| 390 | + .withDataSource(DATA_SOURCE) |
| 391 | + .withOptions(JDBCVectorStoreOptions.builder() |
| 392 | + .withQueryProvider(queryProvider) |
| 393 | + .build()) |
| 394 | + .build(); |
| 395 | + |
| 396 | + String collectionName = "test_keytype_" + suffix; |
| 397 | + |
| 398 | + VectorStoreRecordCollection collectionRaw = |
| 399 | + vectorStore.getCollection(collectionName, |
| 400 | + JDBCVectorStoreRecordCollectionOptions.<DummyRecordForKeyTypes>builder() |
| 401 | + .withRecordClass(DummyRecordForKeyTypes.class) |
| 402 | + .withRecordDefinition(definition) |
| 403 | + .build()); |
| 404 | + |
| 405 | + VectorStoreRecordCollection<Object, DummyRecordForKeyTypes> collection = |
| 406 | + (VectorStoreRecordCollection<Object, DummyRecordForKeyTypes>) collectionRaw; |
| 407 | + |
| 408 | + collection.createCollectionAsync().block(); |
| 409 | + |
| 410 | + DummyRecordForKeyTypes record = new DummyRecordForKeyTypes(keyValue, "dummyValue", Arrays.asList(1.0f, 2.0f)); |
| 411 | + collection.upsertAsync(record, null).block(); |
| 412 | + |
| 413 | + DummyRecordForKeyTypes result = collection.getAsync(keyValue, null).block(); |
| 414 | + assertNotNull(result); |
| 415 | + assertEquals("dummyValue", result.getDummy()); |
| 416 | + |
| 417 | + collection.deleteCollectionAsync().block(); |
| 418 | + } |
| 419 | + |
352 | 420 | @Nested |
353 | 421 | class HNSWIndexTests { |
354 | 422 | @Test |
@@ -517,4 +585,48 @@ private static Stream<Arguments> parametersExactSearch() { |
517 | 585 | Arguments.of (DistanceFunction.UNDEFINED, Arrays.asList(0.1000d, 18.9081d, 19.9669d)) |
518 | 586 | ); |
519 | 587 | } |
| 588 | + |
| 589 | + // commented out temporarily because only String type key is supported in |
| 590 | + // JDBCVectorStoreRecordCollection<Record>#getKeyFromRecord: |
| 591 | + // ... |
| 592 | + // return (String) keyField.get(data); |
| 593 | + // ... |
| 594 | + // thus upsertAync/getAsync won't work |
| 595 | + private static Stream<Arguments> supportedKeyTypes() { |
| 596 | + return Stream.of( |
| 597 | + Arguments.of("string", String.class, "asd123")/*, |
| 598 | + Arguments.of("integer", Integer.class, 321), |
| 599 | + Arguments.of("long", Long.class, 5L), |
| 600 | + Arguments.of("short", Short.class, (short) 3), |
| 601 | + Arguments.of("uuid", UUID.class, UUID.randomUUID())*/ |
| 602 | + ); |
| 603 | + } |
| 604 | + |
| 605 | + private static class DummyRecordForKeyTypes { |
| 606 | + private final Object id; |
| 607 | + private final String dummy; |
| 608 | + private final List<Float> vec; |
| 609 | + @JsonCreator |
| 610 | + public DummyRecordForKeyTypes( |
| 611 | + @JsonProperty("id")Object id, |
| 612 | + @JsonProperty("dummy") String dummy, |
| 613 | + @JsonProperty("vec") List<Float> vec) { |
| 614 | + this.id = id; |
| 615 | + this.dummy = dummy; |
| 616 | + this.vec = vec; |
| 617 | + } |
| 618 | + |
| 619 | + public Object getId() { |
| 620 | + return id; |
| 621 | + } |
| 622 | + |
| 623 | + public String getDummy() { |
| 624 | + return dummy; |
| 625 | + } |
| 626 | + |
| 627 | + @Override |
| 628 | + public String toString() { |
| 629 | + return String.valueOf(id); |
| 630 | + } |
| 631 | + } |
520 | 632 | } |
0 commit comments