Skip to content

Commit a9bc539

Browse files
fmeheustpsilberk
authored and committed
Refactoring
# Conflicts: # data/semantickernel-data-oracle/src/main/java/com/microsoft/semantickernel/data/jdbc/oracle/OracleVectorStoreQueryProvider.java
1 parent cf99cf7 commit a9bc539

7 files changed

Lines changed: 350 additions & 22 deletions

File tree

data/semantickernel-data-jdbc/src/main/java/com/microsoft/semantickernel/data/jdbc/JDBCVectorStoreRecordCollection.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ public class JDBCVectorStoreRecordCollection<Record>
3434
implements SQLVectorStoreRecordCollection<String, Record> {
3535

3636
private final String collectionName;
37-
private final VectorStoreRecordDefinition recordDefinition;
38-
private final VectorStoreRecordMapper<Record, ResultSet> vectorStoreRecordMapper;
37+
protected final VectorStoreRecordDefinition recordDefinition;
38+
protected final VectorStoreRecordMapper<Record, ResultSet> vectorStoreRecordMapper;
3939
private final JDBCVectorStoreRecordCollectionOptions<Record> options;
40-
private final SQLVectorStoreQueryProvider queryProvider;
40+
protected final SQLVectorStoreQueryProvider queryProvider;
4141

4242
/**
4343
* Creates a new instance of the {@link JDBCVectorStoreRecordCollection}.

data/semantickernel-data-oracle/src/main/java/com/microsoft/semantickernel/data/jdbc/oracle/OracleDataTypesMapping.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package com.microsoft.semantickernel.data.jdbc.oracle;
22

3+
/**
4+
* Defines Oracle database type constants for supported field types.
5+
*/
36
public class OracleDataTypesMapping {
47
public static final String STRING_VARCHAR = "NVARCHAR2(%s)";
58
public static final String STRING_CLOB = "CLOB";
@@ -15,4 +18,5 @@ public class OracleDataTypesMapping {
1518
public static final String OFFSET_DATE_TIME = "TIMESTAMP(7) WITH TIME ZONE";
1619
public static final String UUID = "RAW(16)";
1720
public static final String JSON = "JSON";
21+
public static final String VECTOR_FLOAT = "VECTOR(%s, FLOAT32)";
1822
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
package com.microsoft.semantickernel.data.jdbc.oracle;
2+
3+
import com.microsoft.semantickernel.data.jdbc.oracle.OracleVectorStoreQueryProvider.StringTypeMapping;
4+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDataField;
5+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordKeyField;
6+
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordVectorField;
7+
import oracle.jdbc.OracleTypes;
8+
import java.math.BigDecimal;
9+
import java.time.OffsetDateTime;
10+
import java.util.Collection;
11+
import java.util.HashMap;
12+
import java.util.List;
13+
import java.util.Map;
14+
import java.util.UUID;
15+
import java.util.logging.Logger;
16+
import java.util.stream.Collectors;
17+
18+
/**
19+
* Helper class that maps Java field types to Oracle database types and builds the column and index definitions for record fields.
20+
*/
21+
public class OracleVectorStoreFieldHelper {
22+
// Named after this helper class so that its messages are attributed to the class that emits them.
private static final Logger LOGGER = Logger.getLogger(OracleVectorStoreFieldHelper.class.getName());
23+
24+
/**
25+
* Maps supported key java classes to Oracle database types
26+
*/
27+
private static final HashMap<Class<?>, String> supportedKeyTypes = new HashMap() {
28+
{
29+
put(String.class, String.format(OracleDataTypesMapping.STRING_VARCHAR, 255));
30+
put(short.class, OracleDataTypesMapping.SHORT);
31+
put(Short.class, OracleDataTypesMapping.SHORT);
32+
put(int.class, OracleDataTypesMapping.INTEGER);
33+
put(Integer.class, OracleDataTypesMapping.INTEGER);
34+
put(long.class, OracleDataTypesMapping.LONG);
35+
put(Long.class, OracleDataTypesMapping.LONG);
36+
put(UUID .class, OracleDataTypesMapping.UUID);
37+
}
38+
};
39+
40+
/**
41+
* Maps supported vector java classes to Oracle database types
42+
*/
43+
private static final Map<Class<?>, String> supportedVectorTypes = new HashMap() {
44+
{
45+
put(String.class, OracleDataTypesMapping.VECTOR_FLOAT);
46+
put(List.class, OracleDataTypesMapping.VECTOR_FLOAT);
47+
put(Collection.class, OracleDataTypesMapping.VECTOR_FLOAT);
48+
put(float[].class, OracleDataTypesMapping.VECTOR_FLOAT);
49+
put(Float[].class, OracleDataTypesMapping.VECTOR_FLOAT);
50+
/*
51+
put(byte[].class,"VECTOR(%s, INT8)");
52+
put(Byte[].class,"VECTOR(%s, INT8)");
53+
put(double[].class,"VECTOR(%s, FLOAT64)");
54+
put(Double[].class,"VECTOR(%s, FLOAT64)");
55+
put(boolean[].class,"VECTOR(%s, BINARY)");
56+
put(Boolean[].class,"VECTOR(%s, BINARY)");
57+
*/
58+
}
59+
};
60+
61+
/**
62+
* Maps supported data java classes to Oracle database types
63+
*/
64+
private static final HashMap<Class<?>, String> supportedDataTypes = new HashMap() {
65+
{
66+
put(byte.class, OracleDataTypesMapping.BYTE);
67+
put(Byte.class, OracleDataTypesMapping.BYTE);
68+
put(short.class, OracleDataTypesMapping.SHORT);
69+
put(Short.class, OracleDataTypesMapping.SHORT);
70+
put(int.class, OracleDataTypesMapping.INTEGER);
71+
put(Integer.class, OracleDataTypesMapping.INTEGER);
72+
put(long.class, OracleDataTypesMapping.LONG);
73+
put(Long.class, OracleDataTypesMapping.LONG);
74+
put(Float.class, OracleDataTypesMapping.FLOAT);
75+
put(float.class, OracleDataTypesMapping.FLOAT);
76+
put(Double.class, OracleDataTypesMapping.DOUBLE);
77+
put(double.class, OracleDataTypesMapping.DOUBLE);
78+
put(BigDecimal.class, OracleDataTypesMapping.DECIMAL);
79+
put(Boolean.class, OracleDataTypesMapping.BOOLEAN);
80+
put(boolean.class, OracleDataTypesMapping.BOOLEAN);
81+
put(OffsetDateTime.class, OracleDataTypesMapping.OFFSET_DATE_TIME);
82+
put(UUID.class, OracleDataTypesMapping.UUID);
83+
put(byte[].class, OracleDataTypesMapping.BYTE_ARRAY);
84+
put(List.class, OracleDataTypesMapping.JSON);
85+
}
86+
87+
};
88+
89+
/**
90+
* Maps vector type to OracleTypes. Only needed if types other than FLOAT_32 are supported.
91+
*/
92+
private static final Map<Class<?>, Integer> mapOracleTypeToVector = new HashMap() {
93+
{
94+
put(float[].class, OracleTypes.VECTOR_FLOAT32);
95+
put(Float[].class, OracleTypes.VECTOR_FLOAT32);
96+
/*
97+
put(byte[].class, OracleTypes.VECTOR_INT8);
98+
put(Byte[].class, OracleTypes.VECTOR_INT8);
99+
put(Double[].class, OracleTypes.VECTOR_FLOAT64);
100+
put(double[].class, OracleTypes.VECTOR_FLOAT64);
101+
put(Boolean[].class, OracleTypes.VECTOR_BINARY);
102+
put(boolean[].class, OracleTypes.VECTOR_BINARY);
103+
*/
104+
}
105+
};
106+
107+
/**
108+
* Gets the mapping between the supported Java key types and the Oracle database type.
109+
*
110+
* @return the mapping between the supported Java key types and the Oracle database type.
111+
*/
112+
public static HashMap<Class<?>, String> getSupportedKeyTypes() {
113+
return supportedKeyTypes;
114+
}
115+
116+
/**
117+
* Gets the mapping between the supported Java data types and the Oracle database type.
118+
*
119+
* @return the mapping between the supported Java data types and the Oracle database type.
120+
*/
121+
public static Map<Class<?>, String> getSupportedDataTypes(
122+
StringTypeMapping stringTypeMapping, int defaultVarCharLength) {
123+
124+
if (stringTypeMapping.equals(StringTypeMapping.USE_VARCHAR)) {
125+
supportedDataTypes.put(String.class, String.format(OracleDataTypesMapping.STRING_VARCHAR, defaultVarCharLength));
126+
} else {
127+
supportedDataTypes.put(String.class, OracleDataTypesMapping.STRING_CLOB);
128+
}
129+
return supportedDataTypes;
130+
}
131+
132+
/**
133+
* Gets the mapping between the supported Java data types and the Oracle database type.
134+
*
135+
* @return the mapping between the supported Java data types and the Oracle database type.
136+
*/
137+
public static Map<Class<?>, String> getSupportedVectorTypes() {
138+
return supportedVectorTypes;
139+
}
140+
141+
/**
142+
* Generates the statement to create the index according to the vector field definition.
143+
*
144+
* @return the CREATE VECTOR INDEX statement to create the index according to the vector
145+
* field definition.
146+
*/
147+
public static String getCreateVectorIndexStatement(VectorStoreRecordVectorField field, String collectionTableName) {
148+
switch (field.getIndexKind()) {
149+
case IVFFLAT:
150+
return "CREATE VECTOR INDEX IF NOT EXISTS "
151+
+ getIndexName(field.getEffectiveStorageName())
152+
+ " ON "
153+
+ collectionTableName + "( " + field.getEffectiveStorageName() + " ) "
154+
+ " ORGANIZATION NEIGHBOR PARTITIONS "
155+
+ " WITH DISTANCE COSINE "
156+
+ "PARAMETERS ( TYPE IVF )";
157+
case HNSW:
158+
return "CREATE VECTOR INDEX IF NOT EXISTS " + getIndexName(field.getEffectiveStorageName())
159+
+ " ON "
160+
+ collectionTableName + "( " + field.getEffectiveStorageName() + " ) "
161+
+ "ORGANIZATION INMEMORY GRAPH "
162+
+ "WITH DISTANCE COSINE "
163+
+ "PARAMETERS (TYPE HNSW)";
164+
case UNDEFINED:
165+
return null;
166+
default:
167+
LOGGER.warning("Unsupported index kind: " + field.getIndexKind());
168+
return null;
169+
}
170+
}
171+
172+
/**
173+
* Generates the statement to create the index according to the field definition.
174+
*
175+
* @return the CREATE INDEX statement to create the index according to the field definition.
176+
*/
177+
public static String createIndexForDataField(String collectionTableName, VectorStoreRecordDataField dataField, Map<Class<?>, String> supportedDataTypes) {
178+
if (supportedDataTypes.get(dataField.getFieldType()) == "JSON") {
179+
String dataFieldIndex = "CREATE MULTIVALUE INDEX %s ON %s t (t.%s.%s)";
180+
return String.format(dataFieldIndex,
181+
collectionTableName + "_" + dataField.getEffectiveStorageName(),
182+
collectionTableName,
183+
dataField.getEffectiveStorageName(),
184+
getFunctionForType(supportedDataTypes.get(dataField.getFieldSubType())));
185+
} else {
186+
String dataFieldIndex = "CREATE INDEX %s ON %s (%s ASC)";
187+
return String.format(dataFieldIndex,
188+
collectionTableName + "_" + dataField.getEffectiveStorageName(),
189+
collectionTableName,
190+
dataField.getEffectiveStorageName()
191+
);
192+
}
193+
}
194+
195+
/**
196+
* Gets the function that allows to return the function that converts the JSON value to the
197+
* data type.
198+
* @param jdbcType The JDBC type.
199+
* @return the function that allows to return the function that converts the JSON value to the
200+
* data type.
201+
*/
202+
private static String getFunctionForType(String jdbcType) {
203+
switch (jdbcType) {
204+
case OracleDataTypesMapping.BOOLEAN:
205+
return "boolean()";
206+
case OracleDataTypesMapping.BYTE:
207+
case OracleDataTypesMapping.SHORT:
208+
case OracleDataTypesMapping.INTEGER:
209+
case OracleDataTypesMapping.LONG:
210+
case OracleDataTypesMapping.FLOAT:
211+
case OracleDataTypesMapping.DOUBLE:
212+
case OracleDataTypesMapping.DECIMAL:
213+
return "numberOnly()";
214+
case OracleDataTypesMapping.OFFSET_DATE_TIME:
215+
return "timestamp()";
216+
default:
217+
return "string()";
218+
}
219+
}
220+
221+
/**
222+
* Gets the type of the vector given the field definition. This method is not needed if only
223+
*
224+
* @param field the vector field definition.
225+
* @return returns the type of vector for the given field type.
226+
*/
227+
public static String getTypeForVectorField(VectorStoreRecordVectorField field) {
228+
String dimension = field.getDimensions() > 0 ? String.valueOf(field.getDimensions()) : "*";
229+
return String.format(supportedVectorTypes.get(field.getFieldType()), dimension);
230+
/* Not needed since all types are FLOAT32
231+
if (field.getFieldSubType() != null) {
232+
String vectorType;
233+
switch (field.getFieldSubType().getName()) {
234+
case "java.lang.Double":
235+
vectorType = "FLOAT64";
236+
break;
237+
case "java.lang.Byte":
238+
vectorType = "INT8";
239+
break;
240+
case "java.lang.Boolean":
241+
vectorType = "BINARY";
242+
break;
243+
default:
244+
vectorType = "FLOAT32";
245+
}
246+
return String.format(supportedVectorTypes.get(field.getFieldType()), dimension, vectorType);
247+
} else {
248+
return String.format(supportedVectorTypes.get(field.getFieldType()), dimension);
249+
}
250+
*/
251+
}
252+
253+
/**
254+
* Gets the JDBC oracle of the vector field definition.
255+
* @param field the vector field definition.
256+
* @return the JDBC oracle type.
257+
*/
258+
public static int getOracleTypeForField(VectorStoreRecordVectorField field) {
259+
if (field.getFieldSubType() == null) {
260+
return mapOracleTypeToVector.get(field.getFieldType()).intValue();
261+
} else {
262+
switch (field.getFieldSubType().getName()) {
263+
case "java.lang.Double":
264+
return OracleTypes.VECTOR_FLOAT64;
265+
case "java.lang.Byte":
266+
return OracleTypes.VECTOR_INT8;
267+
case "java.lang.Boolean":
268+
return OracleTypes.VECTOR_BINARY;
269+
default:
270+
return OracleTypes.VECTOR_FLOAT32;
271+
}
272+
}
273+
}
274+
275+
/**
276+
* Generates the index name given the field name. by suffixing "_VECTOR_INDEX" to the field name.
277+
* @param effectiveStorageName the field name.
278+
* @return the index name.
279+
*/
280+
private static String getIndexName(String effectiveStorageName) {
281+
return effectiveStorageName + "_VECTOR_INDEX";
282+
}
283+
284+
/**
285+
* Returns vector columns names and types for CREATE TABLE statement
286+
* @param fields list of vector record fields.
287+
* @return comma separated list of columns and types for CREATE TABLE statement.
288+
*/
289+
public static String getVectorColumnNamesAndTypes(List<VectorStoreRecordVectorField> fields) {
290+
List<String> columns = fields.stream()
291+
.map(field -> field.getEffectiveStorageName() + " " +
292+
OracleVectorStoreFieldHelper.getTypeForVectorField(field)
293+
).collect(Collectors.toList());
294+
295+
return String.join(", ", columns);
296+
}
297+
298+
/**
299+
* Returns key column names and type for key column for CREATE TABLE statement
300+
* @param field the key field.
301+
* @return column name and type of the key field for CREATE TABLE statement.
302+
*/
303+
public static String getKeyColumnNameAndType(VectorStoreRecordKeyField field) {
304+
return field.getEffectiveStorageName() + " " + supportedKeyTypes.get(field.getFieldType());
305+
}
306+
307+
}

data/semantickernel-data-oracle/src/test/java/com/microsoft/semantickernel/data/jdbc/oracle/Hotel.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ public class Hotel {
4040

4141
@JsonProperty("summaryEmbedding2")
4242
@VectorStoreRecordVector(dimensions = 8, distanceFunction = DistanceFunction.COSINE_DISTANCE, indexKind = IndexKind.HNSW)
43-
private final List<Float> cosineDistance;
43+
private final float[] cosineDistance;
4444

4545
@JsonProperty("summaryEmbedding3")
4646
@VectorStoreRecordVector(dimensions = 8, distanceFunction = DistanceFunction.COSINE_SIMILARITY, indexKind = IndexKind.IVFFLAT)
47-
private final List<Float> cosineSimilarity;
47+
private final float[] cosineSimilarity;
4848

4949
@JsonProperty("summaryEmbedding4")
5050
@VectorStoreRecordVector(dimensions = 8, distanceFunction = DistanceFunction.DOT_PRODUCT, indexKind = IndexKind.IVFFLAT)
51-
private final List<Float> dotProduct;
51+
private final Float[] dotProduct;
5252
@VectorStoreRecordData
5353
private double rating;
5454

@@ -66,9 +66,9 @@ protected Hotel(
6666
@JsonProperty("tags") List<String> tags,
6767
@JsonProperty("summary") String description,
6868
@JsonProperty("summaryEmbedding1") List<Float> euclidean,
69-
@JsonProperty("summaryEmbedding2") List<Float> cosineDistance,
70-
@JsonProperty("summaryEmbedding3") List<Float> cosineSimilarity,
71-
@JsonProperty("summaryEmbedding4") List<Float> dotProduct,
69+
@JsonProperty("summaryEmbedding2") float[] cosineDistance,
70+
@JsonProperty("summaryEmbedding3") float[] cosineSimilarity,
71+
@JsonProperty("summaryEmbedding4") Float[] dotProduct,
7272
@JsonProperty("rating") double rating) {
7373
this.id = id;
7474
this.name = name;
@@ -77,9 +77,9 @@ protected Hotel(
7777
this.tags = tags;
7878
this.description = description;
7979
this.euclidean = euclidean;
80-
this.cosineDistance = euclidean;
81-
this.cosineSimilarity = euclidean;
82-
this.dotProduct = euclidean;
80+
this.cosineDistance = cosineDistance;
81+
this.cosineSimilarity = cosineSimilarity;
82+
this.dotProduct = dotProduct;
8383
this.rating = rating;
8484
}
8585

@@ -107,11 +107,11 @@ public List<Float> getEuclidean() {
107107
return euclidean;
108108
}
109109

110-
public List<Float> getCosineDistance() {
110+
public float[] getCosineDistance() {
111111
return cosineDistance;
112112
}
113113

114-
public List<Float> getDotProduct() {
114+
public Float[] getDotProduct() {
115115
return dotProduct;
116116
}
117117

0 commit comments

Comments
 (0)