|
| 1 | +package com.microsoft.semantickernel.data.jdbc.oracle; |
| 2 | + |
| 3 | +import com.microsoft.semantickernel.data.jdbc.oracle.OracleVectorStoreQueryProvider.StringTypeMapping; |
| 4 | +import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDataField; |
| 5 | +import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordKeyField; |
| 6 | +import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordVectorField; |
| 7 | +import oracle.jdbc.OracleTypes; |
| 8 | +import java.math.BigDecimal; |
| 9 | +import java.time.OffsetDateTime; |
| 10 | +import java.util.Collection; |
| 11 | +import java.util.HashMap; |
| 12 | +import java.util.List; |
| 13 | +import java.util.Map; |
| 14 | +import java.util.UUID; |
| 15 | +import java.util.logging.Logger; |
| 16 | +import java.util.stream.Collectors; |
| 17 | + |
| 18 | +/** |
| 19 | + * Helper class for field operations. |
| 20 | + */ |
| 21 | +public class OracleVectorStoreFieldHelper { |
| 22 | + private static final Logger LOGGER = Logger.getLogger(OracleVectorStoreQueryProvider.class.getName()); |
| 23 | + |
| 24 | + /** |
| 25 | + * Maps supported key java classes to Oracle database types |
| 26 | + */ |
| 27 | + private static final HashMap<Class<?>, String> supportedKeyTypes = new HashMap() { |
| 28 | + { |
| 29 | + put(String.class, String.format(OracleDataTypesMapping.STRING_VARCHAR, 255)); |
| 30 | + put(short.class, OracleDataTypesMapping.SHORT); |
| 31 | + put(Short.class, OracleDataTypesMapping.SHORT); |
| 32 | + put(int.class, OracleDataTypesMapping.INTEGER); |
| 33 | + put(Integer.class, OracleDataTypesMapping.INTEGER); |
| 34 | + put(long.class, OracleDataTypesMapping.LONG); |
| 35 | + put(Long.class, OracleDataTypesMapping.LONG); |
| 36 | + put(UUID .class, OracleDataTypesMapping.UUID); |
| 37 | + } |
| 38 | + }; |
| 39 | + |
| 40 | + /** |
| 41 | + * Maps supported vector java classes to Oracle database types |
| 42 | + */ |
| 43 | + private static final Map<Class<?>, String> supportedVectorTypes = new HashMap() { |
| 44 | + { |
| 45 | + put(String.class, OracleDataTypesMapping.VECTOR_FLOAT); |
| 46 | + put(List.class, OracleDataTypesMapping.VECTOR_FLOAT); |
| 47 | + put(Collection.class, OracleDataTypesMapping.VECTOR_FLOAT); |
| 48 | + put(float[].class, OracleDataTypesMapping.VECTOR_FLOAT); |
| 49 | + put(Float[].class, OracleDataTypesMapping.VECTOR_FLOAT); |
| 50 | +/* |
| 51 | + put(byte[].class,"VECTOR(%s, INT8)"); |
| 52 | + put(Byte[].class,"VECTOR(%s, INT8)"); |
| 53 | + put(double[].class,"VECTOR(%s, FLOAT64)"); |
| 54 | + put(Double[].class,"VECTOR(%s, FLOAT64)"); |
| 55 | + put(boolean[].class,"VECTOR(%s, BINARY)"); |
| 56 | + put(Boolean[].class,"VECTOR(%s, BINARY)"); |
| 57 | + */ |
| 58 | + } |
| 59 | + }; |
| 60 | + |
| 61 | + /** |
| 62 | + * Maps supported data java classes to Oracle database types |
| 63 | + */ |
| 64 | + private static final HashMap<Class<?>, String> supportedDataTypes = new HashMap() { |
| 65 | + { |
| 66 | + put(byte.class, OracleDataTypesMapping.BYTE); |
| 67 | + put(Byte.class, OracleDataTypesMapping.BYTE); |
| 68 | + put(short.class, OracleDataTypesMapping.SHORT); |
| 69 | + put(Short.class, OracleDataTypesMapping.SHORT); |
| 70 | + put(int.class, OracleDataTypesMapping.INTEGER); |
| 71 | + put(Integer.class, OracleDataTypesMapping.INTEGER); |
| 72 | + put(long.class, OracleDataTypesMapping.LONG); |
| 73 | + put(Long.class, OracleDataTypesMapping.LONG); |
| 74 | + put(Float.class, OracleDataTypesMapping.FLOAT); |
| 75 | + put(float.class, OracleDataTypesMapping.FLOAT); |
| 76 | + put(Double.class, OracleDataTypesMapping.DOUBLE); |
| 77 | + put(double.class, OracleDataTypesMapping.DOUBLE); |
| 78 | + put(BigDecimal.class, OracleDataTypesMapping.DECIMAL); |
| 79 | + put(Boolean.class, OracleDataTypesMapping.BOOLEAN); |
| 80 | + put(boolean.class, OracleDataTypesMapping.BOOLEAN); |
| 81 | + put(OffsetDateTime.class, OracleDataTypesMapping.OFFSET_DATE_TIME); |
| 82 | + put(UUID.class, OracleDataTypesMapping.UUID); |
| 83 | + put(byte[].class, OracleDataTypesMapping.BYTE_ARRAY); |
| 84 | + put(List.class, OracleDataTypesMapping.JSON); |
| 85 | + } |
| 86 | + |
| 87 | + }; |
| 88 | + |
| 89 | + /** |
| 90 | + * Maps vector type to OracleTypes. Only needed if types other than FLOAT_32 are supported. |
| 91 | + */ |
| 92 | + private static final Map<Class<?>, Integer> mapOracleTypeToVector = new HashMap() { |
| 93 | + { |
| 94 | + put(float[].class, OracleTypes.VECTOR_FLOAT32); |
| 95 | + put(Float[].class, OracleTypes.VECTOR_FLOAT32); |
| 96 | +/* |
| 97 | + put(byte[].class, OracleTypes.VECTOR_INT8); |
| 98 | + put(Byte[].class, OracleTypes.VECTOR_INT8); |
| 99 | + put(Double[].class, OracleTypes.VECTOR_FLOAT64); |
| 100 | + put(double[].class, OracleTypes.VECTOR_FLOAT64); |
| 101 | + put(Boolean[].class, OracleTypes.VECTOR_BINARY); |
| 102 | + put(boolean[].class, OracleTypes.VECTOR_BINARY); |
| 103 | +*/ |
| 104 | + } |
| 105 | + }; |
| 106 | + |
| 107 | + /** |
| 108 | + * Gets the mapping between the supported Java key types and the Oracle database type. |
| 109 | + * |
| 110 | + * @return the mapping between the supported Java key types and the Oracle database type. |
| 111 | + */ |
| 112 | + public static HashMap<Class<?>, String> getSupportedKeyTypes() { |
| 113 | + return supportedKeyTypes; |
| 114 | + } |
| 115 | + |
| 116 | + /** |
| 117 | + * Gets the mapping between the supported Java data types and the Oracle database type. |
| 118 | + * |
| 119 | + * @return the mapping between the supported Java data types and the Oracle database type. |
| 120 | + */ |
| 121 | + public static Map<Class<?>, String> getSupportedDataTypes( |
| 122 | + StringTypeMapping stringTypeMapping, int defaultVarCharLength) { |
| 123 | + |
| 124 | + if (stringTypeMapping.equals(StringTypeMapping.USE_VARCHAR)) { |
| 125 | + supportedDataTypes.put(String.class, String.format(OracleDataTypesMapping.STRING_VARCHAR, defaultVarCharLength)); |
| 126 | + } else { |
| 127 | + supportedDataTypes.put(String.class, OracleDataTypesMapping.STRING_CLOB); |
| 128 | + } |
| 129 | + return supportedDataTypes; |
| 130 | + } |
| 131 | + |
| 132 | + /** |
| 133 | + * Gets the mapping between the supported Java data types and the Oracle database type. |
| 134 | + * |
| 135 | + * @return the mapping between the supported Java data types and the Oracle database type. |
| 136 | + */ |
| 137 | + public static Map<Class<?>, String> getSupportedVectorTypes() { |
| 138 | + return supportedVectorTypes; |
| 139 | + } |
| 140 | + |
| 141 | + /** |
| 142 | + * Generates the statement to create the index according to the vector field definition. |
| 143 | + * |
| 144 | + * @return the CREATE VECTOR INDEX statement to create the index according to the vector |
| 145 | + * field definition. |
| 146 | + */ |
| 147 | + public static String getCreateVectorIndexStatement(VectorStoreRecordVectorField field, String collectionTableName) { |
| 148 | + switch (field.getIndexKind()) { |
| 149 | + case IVFFLAT: |
| 150 | + return "CREATE VECTOR INDEX IF NOT EXISTS " |
| 151 | + + getIndexName(field.getEffectiveStorageName()) |
| 152 | + + " ON " |
| 153 | + + collectionTableName + "( " + field.getEffectiveStorageName() + " ) " |
| 154 | + + " ORGANIZATION NEIGHBOR PARTITIONS " |
| 155 | + + " WITH DISTANCE COSINE " |
| 156 | + + "PARAMETERS ( TYPE IVF )"; |
| 157 | + case HNSW: |
| 158 | + return "CREATE VECTOR INDEX IF NOT EXISTS " + getIndexName(field.getEffectiveStorageName()) |
| 159 | + + " ON " |
| 160 | + + collectionTableName + "( " + field.getEffectiveStorageName() + " ) " |
| 161 | + + "ORGANIZATION INMEMORY GRAPH " |
| 162 | + + "WITH DISTANCE COSINE " |
| 163 | + + "PARAMETERS (TYPE HNSW)"; |
| 164 | + case UNDEFINED: |
| 165 | + return null; |
| 166 | + default: |
| 167 | + LOGGER.warning("Unsupported index kind: " + field.getIndexKind()); |
| 168 | + return null; |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + /** |
| 173 | + * Generates the statement to create the index according to the field definition. |
| 174 | + * |
| 175 | + * @return the CREATE INDEX statement to create the index according to the field definition. |
| 176 | + */ |
| 177 | + public static String createIndexForDataField(String collectionTableName, VectorStoreRecordDataField dataField, Map<Class<?>, String> supportedDataTypes) { |
| 178 | + if (supportedDataTypes.get(dataField.getFieldType()) == "JSON") { |
| 179 | + String dataFieldIndex = "CREATE MULTIVALUE INDEX %s ON %s t (t.%s.%s)"; |
| 180 | + return String.format(dataFieldIndex, |
| 181 | + collectionTableName + "_" + dataField.getEffectiveStorageName(), |
| 182 | + collectionTableName, |
| 183 | + dataField.getEffectiveStorageName(), |
| 184 | + getFunctionForType(supportedDataTypes.get(dataField.getFieldSubType()))); |
| 185 | + } else { |
| 186 | + String dataFieldIndex = "CREATE INDEX %s ON %s (%s ASC)"; |
| 187 | + return String.format(dataFieldIndex, |
| 188 | + collectionTableName + "_" + dataField.getEffectiveStorageName(), |
| 189 | + collectionTableName, |
| 190 | + dataField.getEffectiveStorageName() |
| 191 | + ); |
| 192 | + } |
| 193 | + } |
| 194 | + |
| 195 | + /** |
| 196 | + * Gets the function that allows to return the function that converts the JSON value to the |
| 197 | + * data type. |
| 198 | + * @param jdbcType The JDBC type. |
| 199 | + * @return the function that allows to return the function that converts the JSON value to the |
| 200 | + * data type. |
| 201 | + */ |
| 202 | + private static String getFunctionForType(String jdbcType) { |
| 203 | + switch (jdbcType) { |
| 204 | + case OracleDataTypesMapping.BOOLEAN: |
| 205 | + return "boolean()"; |
| 206 | + case OracleDataTypesMapping.BYTE: |
| 207 | + case OracleDataTypesMapping.SHORT: |
| 208 | + case OracleDataTypesMapping.INTEGER: |
| 209 | + case OracleDataTypesMapping.LONG: |
| 210 | + case OracleDataTypesMapping.FLOAT: |
| 211 | + case OracleDataTypesMapping.DOUBLE: |
| 212 | + case OracleDataTypesMapping.DECIMAL: |
| 213 | + return "numberOnly()"; |
| 214 | + case OracleDataTypesMapping.OFFSET_DATE_TIME: |
| 215 | + return "timestamp()"; |
| 216 | + default: |
| 217 | + return "string()"; |
| 218 | + } |
| 219 | + } |
| 220 | + |
| 221 | + /** |
| 222 | + * Gets the type of the vector given the field definition. This method is not needed if only |
| 223 | + * |
| 224 | + * @param field the vector field definition. |
| 225 | + * @return returns the type of vector for the given field type. |
| 226 | + */ |
| 227 | + public static String getTypeForVectorField(VectorStoreRecordVectorField field) { |
| 228 | + String dimension = field.getDimensions() > 0 ? String.valueOf(field.getDimensions()) : "*"; |
| 229 | + return String.format(supportedVectorTypes.get(field.getFieldType()), dimension); |
| 230 | +/* Not needed since all types are FLOAT32 |
| 231 | + if (field.getFieldSubType() != null) { |
| 232 | + String vectorType; |
| 233 | + switch (field.getFieldSubType().getName()) { |
| 234 | + case "java.lang.Double": |
| 235 | + vectorType = "FLOAT64"; |
| 236 | + break; |
| 237 | + case "java.lang.Byte": |
| 238 | + vectorType = "INT8"; |
| 239 | + break; |
| 240 | + case "java.lang.Boolean": |
| 241 | + vectorType = "BINARY"; |
| 242 | + break; |
| 243 | + default: |
| 244 | + vectorType = "FLOAT32"; |
| 245 | + } |
| 246 | + return String.format(supportedVectorTypes.get(field.getFieldType()), dimension, vectorType); |
| 247 | + } else { |
| 248 | + return String.format(supportedVectorTypes.get(field.getFieldType()), dimension); |
| 249 | + } |
| 250 | + */ |
| 251 | + } |
| 252 | + |
| 253 | + /** |
| 254 | + * Gets the JDBC oracle of the vector field definition. |
| 255 | + * @param field the vector field definition. |
| 256 | + * @return the JDBC oracle type. |
| 257 | + */ |
| 258 | + public static int getOracleTypeForField(VectorStoreRecordVectorField field) { |
| 259 | + if (field.getFieldSubType() == null) { |
| 260 | + return mapOracleTypeToVector.get(field.getFieldType()).intValue(); |
| 261 | + } else { |
| 262 | + switch (field.getFieldSubType().getName()) { |
| 263 | + case "java.lang.Double": |
| 264 | + return OracleTypes.VECTOR_FLOAT64; |
| 265 | + case "java.lang.Byte": |
| 266 | + return OracleTypes.VECTOR_INT8; |
| 267 | + case "java.lang.Boolean": |
| 268 | + return OracleTypes.VECTOR_BINARY; |
| 269 | + default: |
| 270 | + return OracleTypes.VECTOR_FLOAT32; |
| 271 | + } |
| 272 | + } |
| 273 | + } |
| 274 | + |
| 275 | + /** |
| 276 | + * Generates the index name given the field name. by suffixing "_VECTOR_INDEX" to the field name. |
| 277 | + * @param effectiveStorageName the field name. |
| 278 | + * @return the index name. |
| 279 | + */ |
| 280 | + private static String getIndexName(String effectiveStorageName) { |
| 281 | + return effectiveStorageName + "_VECTOR_INDEX"; |
| 282 | + } |
| 283 | + |
| 284 | + /** |
| 285 | + * Returns vector columns names and types for CREATE TABLE statement |
| 286 | + * @param fields list of vector record fields. |
| 287 | + * @return comma separated list of columns and types for CREATE TABLE statement. |
| 288 | + */ |
| 289 | + public static String getVectorColumnNamesAndTypes(List<VectorStoreRecordVectorField> fields) { |
| 290 | + List<String> columns = fields.stream() |
| 291 | + .map(field -> field.getEffectiveStorageName() + " " + |
| 292 | + OracleVectorStoreFieldHelper.getTypeForVectorField(field) |
| 293 | + ).collect(Collectors.toList()); |
| 294 | + |
| 295 | + return String.join(", ", columns); |
| 296 | + } |
| 297 | + |
| 298 | + /** |
| 299 | + * Returns key column names and type for key column for CREATE TABLE statement |
| 300 | + * @param field the key field. |
| 301 | + * @return column name and type of the key field for CREATE TABLE statement. |
| 302 | + */ |
| 303 | + public static String getKeyColumnNameAndType(VectorStoreRecordKeyField field) { |
| 304 | + return field.getEffectiveStorageName() + " " + supportedKeyTypes.get(field.getFieldType()); |
| 305 | + } |
| 306 | + |
| 307 | +} |
0 commit comments