Skip to content

Commit 6729109

Browse files
author
Milder Hernandez
authored
Merge pull request #209 from milderhc/sql-query-formatting
Update SQL query formatting in Vector Stores
2 parents 0fd0ce5 + bbec975 commit 6729109

3 files changed

Lines changed: 78 additions & 55 deletions

File tree

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,9 @@ public Map<Class<?>, String> getSupportedVectorTypes() {
166166
*/
167167
@Override
168168
public void prepareVectorStore() {
169-
String createCollectionsTable = "CREATE TABLE IF NOT EXISTS "
170-
+ validateSQLidentifier(collectionsTable)
171-
+ " (collectionId VARCHAR(255) PRIMARY KEY);";
169+
String createCollectionsTable = formatQuery(
170+
"CREATE TABLE IF NOT EXISTS %s (collectionId VARCHAR(255) PRIMARY KEY);",
171+
validateSQLidentifier(collectionsTable));
172172

173173
try (Connection connection = dataSource.getConnection();
174174
PreparedStatement createTable = connection.prepareStatement(createCollectionsTable)) {
@@ -207,8 +207,8 @@ public void validateSupportedTypes(VectorStoreRecordDefinition recordDefinition)
207207
*/
208208
@Override
209209
public boolean collectionExists(String collectionName) {
210-
String query = "SELECT 1 FROM " + validateSQLidentifier(collectionsTable)
211-
+ " WHERE collectionId = ?";
210+
String query = formatQuery("SELECT 1 FROM %s WHERE collectionId = ?",
211+
validateSQLidentifier(collectionsTable));
212212

213213
try (Connection connection = dataSource.getConnection();
214214
PreparedStatement statement = connection.prepareStatement(query)) {
@@ -232,18 +232,19 @@ public boolean collectionExists(String collectionName) {
232232
public void createCollection(String collectionName,
233233
VectorStoreRecordDefinition recordDefinition) {
234234

235-
String createStorageTable = "CREATE TABLE IF NOT EXISTS "
236-
+ getCollectionTableName(collectionName) + " ("
237-
+ getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, "
238-
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
239-
getSupportedDataTypes())
240-
+ ", "
241-
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getVectorFields()),
242-
getSupportedVectorTypes())
243-
+ ");";
235+
String createStorageTable = formatQuery("CREATE TABLE IF NOT EXISTS %s ("
236+
+ "%s VARCHAR(255) PRIMARY KEY, "
237+
+ "%s, "
238+
+ "%s);",
239+
getCollectionTableName(collectionName),
240+
getKeyColumnName(recordDefinition.getKeyField()),
241+
getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
242+
getSupportedDataTypes()),
243+
getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getVectorFields()),
244+
getSupportedVectorTypes()));
244245

245-
String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
246-
+ " (collectionId) VALUES (?)";
246+
String insertCollectionQuery = formatQuery("INSERT INTO %s (collectionId) VALUES (?)",
247+
validateSQLidentifier(collectionsTable));
247248

248249
try (Connection connection = dataSource.getConnection();
249250
PreparedStatement createTable = connection.prepareStatement(createStorageTable)) {
@@ -269,9 +270,10 @@ public void createCollection(String collectionName,
269270
*/
270271
@Override
271272
public void deleteCollection(String collectionName) {
272-
String deleteCollectionOperation = "DELETE FROM " + validateSQLidentifier(collectionsTable)
273-
+ " WHERE collectionId = ?";
274-
String dropTableOperation = "DROP TABLE " + getCollectionTableName(collectionName);
273+
String deleteCollectionOperation = formatQuery("DELETE FROM %s WHERE collectionId = ?",
274+
validateSQLidentifier(collectionsTable));
275+
String dropTableOperation = formatQuery("DROP TABLE %s",
276+
getCollectionTableName(collectionName));
275277

276278
try (Connection connection = dataSource.getConnection();
277279
PreparedStatement deleteCollection = connection
@@ -298,7 +300,8 @@ public void deleteCollection(String collectionName) {
298300
*/
299301
@Override
300302
public List<String> getCollectionNames() {
301-
String query = "SELECT collectionId FROM " + validateSQLidentifier(collectionsTable);
303+
String query = formatQuery("SELECT collectionId FROM %s",
304+
validateSQLidentifier(collectionsTable));
302305

303306
try (Connection connection = dataSource.getConnection();
304307
PreparedStatement statement = connection.prepareStatement(query)) {
@@ -339,10 +342,11 @@ public <Record> List<Record> getRecords(String collectionName, List<String> keys
339342
fields = recordDefinition.getNonVectorFields();
340343
}
341344

342-
String query = "SELECT " + getQueryColumnsFromFields(fields)
343-
+ " FROM " + getCollectionTableName(collectionName)
344-
+ " WHERE " + getKeyColumnName(recordDefinition.getKeyField())
345-
+ " IN (" + getWildcardString(keys.size()) + ")";
345+
String query = formatQuery("SELECT %s FROM %s WHERE %s IN (%s)",
346+
getQueryColumnsFromFields(fields),
347+
getCollectionTableName(collectionName),
348+
getKeyColumnName(recordDefinition.getKeyField()),
349+
getWildcardString(keys.size()));
346350

347351
try (Connection connection = dataSource.getConnection();
348352
PreparedStatement statement = connection.prepareStatement(query)) {
@@ -382,9 +386,10 @@ public void upsertRecords(String collectionName, List<?> records,
382386
@Override
383387
public void deleteRecords(String collectionName, List<String> keys,
384388
VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) {
385-
String query = "DELETE FROM " + getCollectionTableName(collectionName)
386-
+ " WHERE " + getKeyColumnName(recordDefinition.getKeyField())
387-
+ " IN (" + getWildcardString(keys.size()) + ")";
389+
String query = formatQuery("DELETE FROM %s WHERE %s IN (%s)",
390+
getCollectionTableName(collectionName),
391+
getKeyColumnName(recordDefinition.getKeyField()),
392+
getWildcardString(keys.size()));
388393

389394
try (Connection connection = dataSource.getConnection();
390395
PreparedStatement statement = connection.prepareStatement(query)) {
@@ -412,6 +417,17 @@ public static String validateSQLidentifier(String identifier) {
412417
throw new SKException("Invalid SQL identifier: " + identifier);
413418
}
414419

420+
/**
421+
* Builds a query string by substituting {@code args}, in order, into the {@code query} template via {@link String#format(String, Object...)}. Arguments are inserted verbatim; this method performs no validation or escaping of its own.
422+
*
423+
* @param query the format template, e.g. containing {@code %s} placeholders
424+
* @param args the string values substituted for the template's placeholders
425+
* @return the formatted query string
426+
*/
427+
public String formatQuery(String query, String... args) {
428+
return String.format(query, (Object[]) args);
429+
}
430+
415431
/**
416432
* The builder for {@link JDBCVectorStoreDefaultQueryProvider}.
417433
*/

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,16 @@ public void upsertRecords(String collectionName, List<?> records,
8686
List<VectorStoreRecordField> fields = recordDefinition.getAllFields();
8787

8888
String onDuplicateKeyUpdate = fields.stream()
89-
.map(field -> validateSQLidentifier(field.getEffectiveStorageName())
90-
+ " = VALUES(" + validateSQLidentifier(field.getEffectiveStorageName()) + ")")
89+
.map(field -> formatQuery("%s = VALUES(%s)",
90+
validateSQLidentifier(field.getEffectiveStorageName()),
91+
field.getEffectiveStorageName()))
9192
.collect(Collectors.joining(", "));
9293

93-
String query = "INSERT INTO " + getCollectionTableName(collectionName)
94-
+ " (" + getQueryColumnsFromFields(fields) + ")"
95-
+ " VALUES (" + getWildcardString(fields.size()) + ")"
96-
+ " ON DUPLICATE KEY UPDATE " + onDuplicateKeyUpdate;
94+
String query = formatQuery("INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE %s",
95+
getCollectionTableName(collectionName),
96+
getQueryColumnsFromFields(fields),
97+
getWildcardString(fields.size()),
98+
onDuplicateKeyUpdate);
9799

98100
try (Connection connection = dataSource.getConnection();
99101
PreparedStatement statement = connection.prepareStatement(query)) {

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,12 @@ private String createIndexForVectorField(String collectionName,
168168
"Distance function is required for vector field: " + vectorField.getName());
169169
}
170170

171-
return "CREATE INDEX IF NOT EXISTS " + getCollectionTableName(collectionName) + "_index"
172-
+ " ON " + getCollectionTableName(collectionName)
173-
+ " USING " + indexKind.getValue()
174-
+ " (" + vectorField.getName() + " " + distanceFunction.getValue() + ");";
171+
return formatQuery("CREATE INDEX IF NOT EXISTS %s ON %s USING %s (%s %s);",
172+
getCollectionTableName(collectionName) + "_index",
173+
getCollectionTableName(collectionName),
174+
indexKind.getValue(),
175+
vectorField.getName(),
176+
distanceFunction.getValue());
175177
}
176178

177179
/**
@@ -194,14 +196,15 @@ public void createCollection(String collectionName,
194196
try (Connection connection = dataSource.getConnection();
195197
Statement createTableAndIndexes = connection.createStatement()) {
196198

197-
String createStorageTable = "CREATE TABLE IF NOT EXISTS "
198-
+ getCollectionTableName(collectionName) + " ("
199-
+ getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, "
200-
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
201-
supportedDataTypes)
202-
+ ", "
203-
+ getColumnNamesAndTypesForVectorFields(recordDefinition.getVectorFields())
204-
+ ");";
199+
String createStorageTable = formatQuery("CREATE TABLE IF NOT EXISTS %s ("
200+
+ "%s VARCHAR(255) PRIMARY KEY, "
201+
+ "%s, "
202+
+ "%s);",
203+
getCollectionTableName(collectionName),
204+
getKeyColumnName(recordDefinition.getKeyField()),
205+
getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
206+
supportedDataTypes),
207+
getColumnNamesAndTypesForVectorFields(recordDefinition.getVectorFields()));
205208

206209
createTableAndIndexes.addBatch(createStorageTable);
207210
for (VectorStoreRecordVectorField vectorField : vectorFields) {
@@ -217,8 +220,8 @@ public void createCollection(String collectionName,
217220
throw new SKException("Failed to create collection", e);
218221
}
219222

220-
String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
221-
+ " (collectionId) VALUES (?)";
223+
String insertCollectionQuery = formatQuery("INSERT INTO %s (collectionId) VALUES (?)",
224+
validateSQLidentifier(collectionsTable));
222225

223226
try (Connection connection = dataSource.getConnection();
224227
PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) {
@@ -284,16 +287,18 @@ public void upsertRecords(String collectionName, List<?> records,
284287

285288
String onDuplicateKeyUpdate = fields.stream()
286289
.filter(field -> !(field instanceof VectorStoreRecordKeyField)) // Exclude key fields
287-
.map(field -> validateSQLidentifier(field.getEffectiveStorageName())
288-
+ " = EXCLUDED." + validateSQLidentifier(field.getEffectiveStorageName()))
290+
.map(field -> formatQuery("%s = EXCLUDED.%s",
291+
validateSQLidentifier(field.getEffectiveStorageName()),
292+
field.getEffectiveStorageName()))
289293
.collect(Collectors.joining(", "));
290294

291-
String query = "INSERT INTO " + getCollectionTableName(collectionName)
292-
+ " (" + getQueryColumnsFromFields(fields) + ")"
293-
+ " VALUES (" + getWildcardStringWithCast(fields) + ")"
294-
+ " ON CONFLICT (" + getKeyColumnName(recordDefinition.getKeyField())
295-
+ ") DO UPDATE SET "
296-
+ onDuplicateKeyUpdate;
295+
String query = formatQuery(
296+
"INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO UPDATE SET %s",
297+
getCollectionTableName(collectionName),
298+
getQueryColumnsFromFields(fields),
299+
getWildcardStringWithCast(fields),
300+
getKeyColumnName(recordDefinition.getKeyField()),
301+
onDuplicateKeyUpdate);
297302

298303
try (Connection connection = dataSource.getConnection();
299304
PreparedStatement statement = connection.prepareStatement(query)) {

0 commit comments

Comments
 (0)