Skip to content

Commit ec31376

Browse files
fmeheustpsilberk
authored andcommitted
Test and bug fixes
1 parent b3e1f4b commit ec31376

1 file changed

Lines changed: 297 additions & 0 deletions

File tree

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
package com.microsoft.semantickernel.data.jdbc.oracle;
2+
3+
import com.microsoft.semantickernel.data.jdbc.JDBCVectorStore;
4+
import com.microsoft.semantickernel.data.jdbc.JDBCVectorStoreOptions;
5+
import com.microsoft.semantickernel.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
6+
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchFilter;
7+
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult;
8+
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResults;
9+
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordCollection;
10+
import com.microsoft.semantickernel.data.vectorstorage.options.VectorSearchOptions;
11+
import org.junit.jupiter.api.BeforeEach;
12+
import org.junit.jupiter.params.ParameterizedTest;
13+
import org.junit.jupiter.params.provider.Arguments;
14+
import org.junit.jupiter.params.provider.MethodSource;
15+
import java.math.BigDecimal;
16+
import java.nio.charset.StandardCharsets;
17+
import java.time.OffsetDateTime;
18+
import java.util.Arrays;
19+
import java.util.UUID;
20+
import java.util.stream.Stream;
21+
22+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
23+
import static org.junit.jupiter.api.Assertions.assertEquals;
24+
import static org.junit.jupiter.api.Assertions.assertTrue;
25+
26+
public class OracleVectorStoreDataTypeSearchTest extends OracleCommonVectorStoreRecordCollectionTest {
27+
private static final double MIN_NUMBER = 1.0E-130;
28+
private static final BigDecimal BIG_NUMBER = BigDecimal.valueOf(9999999999999999.99);
29+
30+
31+
32+
@ParameterizedTest
33+
@MethodSource("supportedDataTypes")
34+
void testDataTypesSearch (ClassWithAllBoxedTypes record) {
35+
VectorStoreRecordCollection<String, ClassWithAllBoxedTypes> collection = setupBoxed();
36+
37+
collection.upsertAsync(record, null).block();
38+
39+
// boolean
40+
VectorSearchResults<ClassWithAllBoxedTypes> results = collection.searchAsync(
41+
null,
42+
VectorSearchOptions.builder()
43+
.withVectorSearchFilter(
44+
VectorSearchFilter.builder()
45+
.equalTo("booleanValue", record.getBooleanValue()).build()
46+
).build()).block();
47+
48+
assertEquals(1, results.getTotalCount());
49+
assertEquals(record.getBooleanValue(), results.getResults().get(0).getRecord().getBooleanValue());
50+
51+
// byte
52+
results = collection.searchAsync(
53+
null,
54+
VectorSearchOptions.builder()
55+
.withVectorSearchFilter(
56+
VectorSearchFilter.builder()
57+
.equalTo("byteValue", record.getByteValue()).build()
58+
).build()).block();
59+
60+
assertEquals(1, results.getTotalCount());
61+
assertEquals(record.getByteValue(), results.getResults().get(0).getRecord().getByteValue());
62+
63+
// short
64+
results = collection.searchAsync(
65+
null,
66+
VectorSearchOptions.builder()
67+
.withVectorSearchFilter(
68+
VectorSearchFilter.builder()
69+
.equalTo("shortValue", record.getShortValue()).build()
70+
).build()).block();
71+
72+
assertEquals(1, results.getTotalCount());
73+
assertEquals(record.getShortValue(), results.getResults().get(0).getRecord().getShortValue());
74+
75+
// integer
76+
results = collection.searchAsync(
77+
null,
78+
VectorSearchOptions.builder()
79+
.withVectorSearchFilter(
80+
VectorSearchFilter.builder()
81+
.equalTo("integerValue", record.getIntegerValue()).build()
82+
).build()).block();
83+
84+
assertEquals(1, results.getTotalCount());
85+
assertEquals(record.getIntegerValue(), results.getResults().get(0).getRecord().getIntegerValue());
86+
87+
// long
88+
results = collection.searchAsync(
89+
null,
90+
VectorSearchOptions.builder()
91+
.withVectorSearchFilter(
92+
VectorSearchFilter.builder()
93+
.equalTo("longValue", record.getLongValue()).build()
94+
).build()).block();
95+
96+
assertEquals(1, results.getTotalCount());
97+
assertEquals(record.getLongValue(), results.getResults().get(0).getRecord().getLongValue());
98+
99+
// float
100+
results = collection.searchAsync(
101+
null,
102+
VectorSearchOptions.builder()
103+
.withVectorSearchFilter(
104+
VectorSearchFilter.builder()
105+
.equalTo("floatValue", record.getFloatValue()).build()
106+
).build()).block();
107+
108+
assertEquals(1, results.getTotalCount());
109+
assertEquals(record.getFloatValue(), results.getResults().get(0).getRecord().getFloatValue());
110+
111+
// double
112+
results = collection.searchAsync(
113+
null,
114+
VectorSearchOptions.builder()
115+
.withVectorSearchFilter(
116+
VectorSearchFilter.builder()
117+
.equalTo("doubleValue", record.getDoubleValue()).build()
118+
).build()).block();
119+
120+
assertEquals(1, results.getTotalCount());
121+
assertEquals(record.getDoubleValue(), results.getResults().get(0).getRecord().getDoubleValue());
122+
123+
// decimal
124+
results = collection.searchAsync(
125+
null,
126+
VectorSearchOptions.builder()
127+
.withVectorSearchFilter(
128+
VectorSearchFilter.builder()
129+
.equalTo("decimalValue", record.getDecimalValue()).build()
130+
).build()).block();
131+
132+
assertEquals(1, results.getTotalCount());
133+
if (record.getDecimalValue() != null) {
134+
assertEquals(0, record.getDecimalValue()
135+
.compareTo(results.getResults().get(0).getRecord().getDecimalValue()));
136+
} else {
137+
assertEquals(record.getDecimalValue(),
138+
results.getResults().get(0).getRecord().getDecimalValue());
139+
}
140+
141+
// offset date time
142+
results = collection.searchAsync(
143+
null,
144+
VectorSearchOptions.builder()
145+
.withVectorSearchFilter(
146+
VectorSearchFilter.builder()
147+
.equalTo("offsetDateTimeValue", record.getOffsetDateTimeValue()).build()
148+
).build()).block();
149+
150+
assertEquals(1, results.getTotalCount());
151+
if (record.getOffsetDateTimeValue() != null) {
152+
assertTrue(record.getOffsetDateTimeValue()
153+
.isEqual(results.getResults().get(0).getRecord().getOffsetDateTimeValue()));
154+
} else {
155+
assertEquals(record.getOffsetDateTimeValue(),
156+
results.getResults().get(0).getRecord().getOffsetDateTimeValue());
157+
}
158+
159+
// UUID
160+
results = collection.searchAsync(
161+
null,
162+
VectorSearchOptions.builder()
163+
.withVectorSearchFilter(
164+
VectorSearchFilter.builder()
165+
.equalTo("uuidValue", record.getUuidValue()).build()
166+
).build()).block();
167+
168+
assertEquals(1, results.getTotalCount());
169+
assertEquals(record.getUuidValue(), results.getResults().get(0).getRecord().getUuidValue());
170+
171+
// byte array
172+
results = collection.searchAsync(
173+
null,
174+
VectorSearchOptions.builder()
175+
.withVectorSearchFilter(
176+
VectorSearchFilter.builder()
177+
.equalTo("byteArrayValue", record.getByteArrayValue()).build()
178+
).build()).block();
179+
180+
assertEquals(1, results.getTotalCount());
181+
assertArrayEquals(record.getByteArrayValue(), results.getResults().get(0).getRecord().getByteArrayValue());
182+
183+
collection.deleteCollectionAsync().block();
184+
185+
}
186+
187+
188+
public VectorStoreRecordCollection<String, ClassWithAllBoxedTypes> setupBoxed() {
189+
OracleVectorStoreQueryProvider queryProvider = OracleVectorStoreQueryProvider.builder()
190+
.withDataSource(DATA_SOURCE)
191+
.build();
192+
193+
JDBCVectorStore vectorStore = JDBCVectorStore.builder()
194+
.withDataSource(DATA_SOURCE)
195+
.withOptions(JDBCVectorStoreOptions.builder()
196+
.withQueryProvider(queryProvider)
197+
.build())
198+
.build();
199+
200+
VectorStoreRecordCollection<String, ClassWithAllBoxedTypes> collection =
201+
vectorStore.getCollection("BoxedTypes",
202+
JDBCVectorStoreRecordCollectionOptions.<ClassWithAllBoxedTypes>builder()
203+
.withRecordClass(ClassWithAllBoxedTypes.class)
204+
.build()).createCollectionAsync().block();
205+
206+
collection.createCollectionAsync().block();
207+
208+
return collection;
209+
}
210+
211+
212+
private static Stream<Arguments> supportedDataTypes() {
213+
return Stream.of(
214+
Arguments.of(
215+
new ClassWithAllBoxedTypes(
216+
"ID1", true, (byte) 127, (short) 3, 321, 5L,
217+
3.14f, 3.14159265358d, new BigDecimal("12345.67"),
218+
OffsetDateTime.now(), UUID.randomUUID(), "abc".getBytes(StandardCharsets.UTF_8),
219+
Arrays.asList(1.0f, 2.6f),
220+
new Float[] { 0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f }
221+
)
222+
),
223+
Arguments.of(
224+
new ClassWithAllBoxedTypes(
225+
"ID2", false, Byte.MIN_VALUE, Short.MIN_VALUE, Integer.MIN_VALUE, Long.MIN_VALUE,
226+
Float.MIN_VALUE, MIN_NUMBER, BigDecimal.valueOf(MIN_NUMBER),
227+
OffsetDateTime.now(), UUID.randomUUID(), new byte[] {Byte.MIN_VALUE, -10, 0, 10, Byte.MAX_VALUE},
228+
Arrays.asList(Float.MIN_VALUE, -10f, 0f, 10f, Float.MAX_VALUE),
229+
new Float[] { 0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f }
230+
)
231+
),
232+
Arguments.of(
233+
new ClassWithAllBoxedTypes(
234+
"ID3", false, Byte.MAX_VALUE, Short.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE,
235+
Float.MAX_VALUE, BIG_NUMBER.doubleValue(), BIG_NUMBER.subtract(BigDecimal.valueOf(0.01d)),
236+
OffsetDateTime.now(), UUID.randomUUID(), null,
237+
null,
238+
new Float[] { 0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f }
239+
)
240+
),
241+
Arguments.of(
242+
new ClassWithAllBoxedTypes(
243+
"ID3", null, null, null, null, null,
244+
null, null, null,
245+
null, null, null,
246+
null,
247+
null
248+
)
249+
)
250+
);
251+
}
252+
253+
private static Stream<Arguments> supportedDataPrimitiveTypes() {
254+
return Stream.of(
255+
Arguments.of(
256+
new ClassWithAllPrimitiveTypes(
257+
"ID1", true, (byte) 127, (short) 3, 321, 5L,
258+
3.14f, 3.14159265358d, new BigDecimal("12345.67"),
259+
OffsetDateTime.now(), UUID.randomUUID(), "abc".getBytes(StandardCharsets.UTF_8),
260+
Arrays.asList(1.0f, 2.6f),
261+
new float[]{0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f}
262+
)
263+
),
264+
Arguments.of(
265+
new ClassWithAllPrimitiveTypes(
266+
"ID2", false, Byte.MIN_VALUE, Short.MIN_VALUE, Integer.MIN_VALUE,
267+
Long.MIN_VALUE,
268+
Float.MIN_VALUE, MIN_NUMBER, BigDecimal.valueOf(MIN_NUMBER),
269+
OffsetDateTime.now(), UUID.randomUUID(),
270+
new byte[]{Byte.MIN_VALUE, -10, 0, 10, Byte.MAX_VALUE},
271+
Arrays.asList(Float.MIN_VALUE, -10f, 0f, 10f, Float.MAX_VALUE),
272+
new float[]{0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f}
273+
)
274+
),
275+
Arguments.of(
276+
new ClassWithAllPrimitiveTypes(
277+
"ID3", false, Byte.MAX_VALUE, Short.MAX_VALUE, Integer.MAX_VALUE,
278+
Long.MAX_VALUE,
279+
Float.MAX_VALUE, BIG_NUMBER.doubleValue(),
280+
BIG_NUMBER.subtract(BigDecimal.valueOf(0.01d)),
281+
OffsetDateTime.now(), UUID.randomUUID(), null,
282+
null,
283+
new float[]{0.5f, 3.2f, 7.1f, -4.0f, 2.8f, 10.0f, -1.3f, 5.5f}
284+
)
285+
),
286+
Arguments.of(
287+
new ClassWithAllPrimitiveTypes(
288+
"ID3", false, (byte) 0, (short) 0, 0, 0l,
289+
0f, 0d, null,
290+
null, null, null,
291+
null,
292+
null
293+
)
294+
)
295+
);
296+
}
297+
}

0 commit comments

Comments
 (0)