|
29 | 29 | import org.junit.jupiter.params.provider.Arguments; |
30 | 30 | import org.junit.jupiter.params.provider.EnumSource; |
31 | 31 | import org.junit.jupiter.params.provider.MethodSource; |
| 32 | +import java.math.BigDecimal; |
| 33 | +import java.nio.charset.StandardCharsets; |
32 | 34 | import java.sql.Connection; |
33 | 35 | import java.sql.PreparedStatement; |
34 | 36 | import java.sql.ResultSet; |
35 | 37 | import java.sql.SQLException; |
36 | 38 | import java.sql.Statement; |
37 | 39 | import java.time.Duration; |
| 40 | +import java.time.OffsetDateTime; |
38 | 41 | import java.util.Arrays; |
39 | 42 | import java.util.HashMap; |
40 | 43 | import java.util.List; |
| 44 | +import java.util.Map; |
41 | 45 | import java.util.UUID; |
42 | 46 | import java.util.logging.Logger; |
43 | 47 | import java.util.stream.Collectors; |
44 | 48 | import java.util.stream.Stream; |
45 | 49 |
|
| 50 | +import static org.junit.jupiter.api.Assertions.assertArrayEquals; |
46 | 51 | import static org.junit.jupiter.api.Assertions.assertEquals; |
47 | 52 | import static org.junit.jupiter.api.Assertions.assertNotNull; |
48 | 53 | import static org.junit.jupiter.api.Assertions.assertNull; |
@@ -417,6 +422,87 @@ <T> void testKeyTypes(String suffix, Class<?> keyType, Object keyValue) { |
417 | 422 | collection.deleteCollectionAsync().block(); |
418 | 423 | } |
419 | 424 |
|
| 425 | + @ParameterizedTest |
| 426 | + @MethodSource("supportedDataTypes") |
| 427 | + void testDataTypes(String dataFieldName, Class<?> dataFieldType, Object dataFieldValue, Class<?> fieldSubType) { |
| 428 | + VectorStoreRecordKeyField keyField = VectorStoreRecordKeyField.builder() |
| 429 | + .withName("id") |
| 430 | + .withStorageName("id") |
| 431 | + .withFieldType(String.class) |
| 432 | + .build(); |
| 433 | + |
| 434 | + VectorStoreRecordDataField dataField; |
| 435 | + if (fieldSubType != null) { |
| 436 | + dataField = VectorStoreRecordDataField.builder() |
| 437 | + .withName("dummy") |
| 438 | + .withStorageName("dummy") |
| 439 | + .withFieldType(dataFieldType, fieldSubType) |
| 440 | + .isFilterable(true) |
| 441 | + .build(); |
| 442 | + } else { |
| 443 | + dataField = VectorStoreRecordDataField.builder() |
| 444 | + .withName("dummy") |
| 445 | + .withStorageName("dummy") |
| 446 | + .withFieldType(dataFieldType) |
| 447 | + .isFilterable(true) |
| 448 | + .build(); |
| 449 | + } |
| 450 | + |
| 451 | + VectorStoreRecordVectorField dummyVector = VectorStoreRecordVectorField.builder() |
| 452 | + .withName("vec") |
| 453 | + .withStorageName("vec") |
| 454 | + .withFieldType(List.class) |
| 455 | + .withDimensions(2) |
| 456 | + .withDistanceFunction(DistanceFunction.EUCLIDEAN_DISTANCE) |
| 457 | + .withIndexKind(IndexKind.UNDEFINED) |
| 458 | + .build(); |
| 459 | + |
| 460 | + VectorStoreRecordDefinition definition = VectorStoreRecordDefinition.fromFields( |
| 461 | + Arrays.asList(keyField, dataField, dummyVector) |
| 462 | + ); |
| 463 | + |
| 464 | + OracleVectorStoreQueryProvider queryProvider = OracleVectorStoreQueryProvider.builder() |
| 465 | + .withDataSource(DATA_SOURCE) |
| 466 | + .build(); |
| 467 | + |
| 468 | + JDBCVectorStore vectorStore = JDBCVectorStore.builder() |
| 469 | + .withDataSource(DATA_SOURCE) |
| 470 | + .withOptions(JDBCVectorStoreOptions.builder() |
| 471 | + .withQueryProvider(queryProvider) |
| 472 | + .build()) |
| 473 | + .build(); |
| 474 | + |
| 475 | + String collectionName = "test_datatype_" + dataFieldName; |
| 476 | + |
| 477 | + VectorStoreRecordCollection<String, DummyRecordForDataTypes> collection = |
| 478 | + vectorStore.getCollection(collectionName, |
| 479 | + JDBCVectorStoreRecordCollectionOptions.<DummyRecordForDataTypes> builder() |
| 480 | + .withRecordClass(DummyRecordForDataTypes.class) |
| 481 | + .withRecordDefinition(definition).build()); |
| 482 | + |
| 483 | + collection.createCollectionAsync().block(); |
| 484 | + |
| 485 | + String key = "testid"; |
| 486 | + |
| 487 | + DummyRecordForDataTypes record = |
| 488 | + new DummyRecordForDataTypes(key, dataFieldValue, Arrays.asList(1.0f, 2.0f)); |
| 489 | + |
| 490 | + collection.upsertAsync(record, null).block(); |
| 491 | + |
| 492 | + DummyRecordForDataTypes result = collection.getAsync(key, null).block(); |
| 493 | + assertNotNull(result); |
| 494 | + |
| 495 | + if (dataFieldValue instanceof Number && result.getDummy() instanceof Number) { |
| 496 | + assertEquals(((Number) dataFieldValue).doubleValue(), ((Number) result.getDummy()).doubleValue()); |
| 497 | + } else if (dataFieldValue instanceof byte[]) { |
| 498 | + assertArrayEquals((byte[]) dataFieldValue, (byte[]) result.getDummy()); |
| 499 | + } else { |
| 500 | + assertEquals(dataFieldValue, result.getDummy()); |
| 501 | + } |
| 502 | + |
| 503 | + collection.deleteCollectionAsync().block(); |
| 504 | + } |
| 505 | + |
420 | 506 | @Nested |
421 | 507 | class HNSWIndexTests { |
422 | 508 | @Test |
@@ -602,6 +688,25 @@ private static Stream<Arguments> supportedKeyTypes() { |
602 | 688 | ); |
603 | 689 | } |
604 | 690 |
|
| 691 | + private static Stream<Arguments> supportedDataTypes() { |
| 692 | + return Stream.of( |
| 693 | + Arguments.of("string", String.class, "asd123", null), |
| 694 | + Arguments.of("boolean_true", Boolean.class, true, null), |
| 695 | + Arguments.of("boolean_false", Boolean.class, false, null), |
| 696 | + Arguments.of("byte", Byte.class, (byte) 127, null), |
| 697 | + Arguments.of("short", Short.class, (short) 3, null), |
| 698 | + Arguments.of("integer", Integer.class, 321, null), |
| 699 | + Arguments.of("long", Long.class, 5L, null), |
| 700 | + Arguments.of("float", Float.class, 3.14f, null), |
| 701 | + Arguments.of("double", double.class, 3.14159265358d, null), |
| 702 | + Arguments.of("decimal", BigDecimal.class, new BigDecimal("12345.67"), null), |
| 703 | + //Arguments.of("timestamp", OffsetDateTime.class, OffsetDateTime.now(), null) |
| 704 | + //Arguments.of("uuid", UUID.class, UUID.randomUUID(), null) |
| 705 | + Arguments.of("byte_array", byte[].class, "abc".getBytes(StandardCharsets.UTF_8), null), |
| 706 | + Arguments.of("json", List.class, Arrays.asList("a", "s", "d"), String.class) |
| 707 | + ); |
| 708 | + } |
| 709 | + |
605 | 710 | private static class DummyRecordForKeyTypes { |
606 | 711 | private final Object id; |
607 | 712 | private final String dummy; |
@@ -629,4 +734,32 @@ public String toString() { |
629 | 734 | return String.valueOf(id); |
630 | 735 | } |
631 | 736 | } |
| 737 | + |
| 738 | + private static class DummyRecordForDataTypes { |
| 739 | + private final String id; |
| 740 | + private final Object dummy; |
| 741 | + private final List<Float> vec; |
| 742 | + @JsonCreator |
| 743 | + public DummyRecordForDataTypes( |
| 744 | + @JsonProperty("id") String id, |
| 745 | + @JsonProperty("dummy") Object dummy, |
| 746 | + @JsonProperty("vec") List<Float> vec) { |
| 747 | + this.id = id; |
| 748 | + this.dummy = dummy; |
| 749 | + this.vec = vec; |
| 750 | + } |
| 751 | + |
| 752 | + public String getId() { |
| 753 | + return id; |
| 754 | + } |
| 755 | + |
| 756 | + public Object getDummy() { |
| 757 | + return dummy; |
| 758 | + } |
| 759 | + |
| 760 | + @Override |
| 761 | + public String toString() { |
| 762 | + return String.valueOf(id); |
| 763 | + } |
| 764 | + } |
632 | 765 | } |
0 commit comments