Skip to content

Commit 5d9448a

Browse files
author
Milder Hernandez Cagua
committed
Add Postgres vector index support
1 parent 5642738 commit 5d9448a

8 files changed

Lines changed: 150 additions & 17 deletions

File tree

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ public class Hotel {
1717
private final String description;
1818
@VectorStoreRecordVectorAttribute(dimensions = 3)
1919
private final List<Float> descriptionEmbedding;
20+
@VectorStoreRecordVectorAttribute(dimensions = 3, indexKind = "hnsw", distanceFunction = "cosine")
21+
private final List<Float> additionalEmbedding;
2022
@VectorStoreRecordDataAttribute
2123
private double rating;
2224

@@ -30,6 +32,7 @@ public Hotel(String id, String name, int code, String description, List<Float> d
3032
this.code = code;
3133
this.description = description;
3234
this.descriptionEmbedding = descriptionEmbedding;
35+
this.additionalEmbedding = descriptionEmbedding;
3336
this.rating = rating;
3437
}
3538

@@ -52,6 +55,9 @@ public String getDescription() {
5255
public List<Float> getDescriptionEmbedding() {
5356
return descriptionEmbedding;
5457
}
58+
public List<Float> getAdditionalEmbedding() {
59+
return additionalEmbedding;
60+
}
5561

5662
public double getRating() {
5763
return rating;

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreCollectionCreateMapping.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ private static VectorSearchAlgorithmMetric getAlgorithmMetric(
3434
}
3535

3636
switch (vectorField.getDistanceFunction()) {
37-
case COSINE_SIMILARITY:
37+
case COSINE:
3838
return VectorSearchAlgorithmMetric.COSINE;
3939
case DOT_PRODUCT:
4040
return VectorSearchAlgorithmMetric.DOT_PRODUCT;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
package com.microsoft.semantickernel.connectors.data.postgres;
3+
4+
import com.microsoft.semantickernel.data.recorddefinition.DistanceFunction;
5+
6+
public enum PostgreSQLVectorDistanceFunction {
7+
L2("vector_l2_ops", "<->"), COSINE("vector_cosine_ops", "<=>"), INNER_PRODUCT("vector_ip_ops",
8+
"<#>");
9+
10+
private final String value;
11+
private final String operator;
12+
13+
PostgreSQLVectorDistanceFunction(String value, String operator) {
14+
this.value = value;
15+
this.operator = operator;
16+
}
17+
18+
public String getValue() {
19+
return value;
20+
}
21+
22+
public String getOperator() {
23+
return operator;
24+
}
25+
26+
public static PostgreSQLVectorDistanceFunction fromDistanceFunction(DistanceFunction function) {
27+
if (function == null) {
28+
return null;
29+
}
30+
31+
switch (function) {
32+
case EUCLIDEAN:
33+
return L2;
34+
case COSINE:
35+
return COSINE;
36+
case DOT_PRODUCT:
37+
return INNER_PRODUCT;
38+
default:
39+
throw new IllegalArgumentException("Unsupported distance function: " + function);
40+
}
41+
}
42+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
package com.microsoft.semantickernel.connectors.data.postgres;
3+
4+
import com.microsoft.semantickernel.data.recorddefinition.IndexKind;
5+
6+
public enum PostgreSQLVectorIndexKind {
7+
HNSW("hnsw"), IVFFLAT("ivfflat");
8+
9+
private final String value;
10+
11+
PostgreSQLVectorIndexKind(String value) {
12+
this.value = value;
13+
}
14+
15+
public String getValue() {
16+
return value;
17+
}
18+
19+
public static PostgreSQLVectorIndexKind fromIndexKind(IndexKind indexKind) {
20+
if (indexKind == null) {
21+
return null;
22+
}
23+
24+
switch (indexKind) {
25+
case HNSW:
26+
return HNSW;
27+
case FLAT:
28+
return IVFFLAT;
29+
default:
30+
throw new IllegalArgumentException("Unsupported index kind: " + indexKind);
31+
}
32+
}
33+
}

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import com.fasterxml.jackson.databind.ObjectMapper;
66
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider;
77
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
8+
import com.microsoft.semantickernel.data.recorddefinition.DistanceFunction;
9+
import com.microsoft.semantickernel.data.recorddefinition.IndexKind;
810
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
911
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
1012
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField;
@@ -19,6 +21,7 @@
1921
import java.sql.PreparedStatement;
2022
import java.sql.ResultSet;
2123
import java.sql.SQLException;
24+
import java.sql.Statement;
2225
import java.time.OffsetDateTime;
2326
import java.util.ArrayList;
2427
import java.util.Collection;
@@ -155,6 +158,29 @@ private String getColumnNamesAndTypesForVectorFields(List<VectorStoreRecordVecto
155158
return columnNames.toString();
156159
}
157160

161+
private String createIndexForVectorField(String collectionName,
162+
VectorStoreRecordVectorField vectorField) {
163+
PostgreSQLVectorIndexKind indexKind = PostgreSQLVectorIndexKind
164+
.fromIndexKind(vectorField.getIndexKind());
165+
PostgreSQLVectorDistanceFunction distanceFunction = PostgreSQLVectorDistanceFunction
166+
.fromDistanceFunction(vectorField.getDistanceFunction());
167+
168+
// If indexKind is not specified, no index is created
169+
// and pgvector performs exact nearest neighbor search.
170+
if (indexKind == null) {
171+
return null;
172+
}
173+
if (distanceFunction == null) {
174+
throw new SKException(
175+
"Distance function is required for vector field: " + vectorField.getName());
176+
}
177+
178+
return "CREATE INDEX IF NOT EXISTS " + getCollectionTableName(collectionName) + "_index"
179+
+ " ON " + getCollectionTableName(collectionName)
180+
+ " USING " + indexKind.getValue()
181+
+ " (" + vectorField.getName() + " " + distanceFunction.getValue() + ");";
182+
}
183+
158184
/**
159185
* Creates a collection.
160186
*
@@ -164,29 +190,44 @@ private String getColumnNamesAndTypesForVectorFields(List<VectorStoreRecordVecto
164190
* @throws SKException if an error occurs while creating the collection
165191
*/
166192
@Override
167-
@SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
193+
@SuppressFBWarnings(value = {
194+
"SQL_NONCONSTANT_STRING_PASSED_TO_EXECUTE",
195+
"SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING"
196+
}) // SQL query is generated dynamically with valid identifiers
168197
public void createCollection(String collectionName, Class<?> recordClass,
169198
VectorStoreRecordDefinition recordDefinition) {
170199
Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass);
171200
List<Field> dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass);
172201

173-
String createStorageTable = "CREATE TABLE IF NOT EXISTS "
174-
+ getCollectionTableName(collectionName)
175-
+ " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, "
176-
+ getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", "
177-
+ getColumnNamesAndTypesForVectorFields(recordDefinition.getVectorFields(), recordClass)
178-
+ ");";
179-
180-
String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
181-
+ " (collectionId) VALUES (?)";
202+
List<VectorStoreRecordVectorField> vectorFields = recordDefinition.getVectorFields();
182203

183204
try (Connection connection = dataSource.getConnection();
184-
PreparedStatement createTable = connection.prepareStatement(createStorageTable)) {
185-
createTable.execute();
205+
Statement createTableAndIndexes = connection.createStatement()) {
206+
207+
String createStorageTable = "CREATE TABLE IF NOT EXISTS "
208+
+ getCollectionTableName(collectionName)
209+
+ " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, "
210+
+ getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", "
211+
+ getColumnNamesAndTypesForVectorFields(vectorFields, recordClass)
212+
+ ");";
213+
214+
createTableAndIndexes.addBatch(createStorageTable);
215+
for (VectorStoreRecordVectorField vectorField : vectorFields) {
216+
String createVectorIndex = createIndexForVectorField(collectionName, vectorField);
217+
218+
if (createVectorIndex != null) {
219+
createTableAndIndexes.addBatch(createVectorIndex);
220+
}
221+
}
222+
223+
createTableAndIndexes.executeBatch();
186224
} catch (SQLException e) {
187225
throw new SKException("Failed to create collection", e);
188226
}
189227

228+
String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
229+
+ " (collectionId) VALUES (?)";
230+
190231
try (Connection connection = dataSource.getConnection();
191232
PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) {
192233
insert.setObject(1, collectionName);

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreCollectionCreateMapping.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ private static String getAlgorithmMetric(
3535
}
3636

3737
switch (vectorField.getDistanceFunction()) {
38-
case COSINE_SIMILARITY:
38+
case COSINE:
3939
return RedisVectorDistanceMetric.COSINE;
4040
case DOT_PRODUCT:
4141
return RedisVectorDistanceMetric.DOT_PRODUCT;

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/DistanceFunction.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,18 @@
22
package com.microsoft.semantickernel.data.recorddefinition;
33

44
public enum DistanceFunction {
5-
COSINE_SIMILARITY("cosineSimilarity"), DOT_PRODUCT("dotProduct"), EUCLIDEAN("euclidean");
5+
/**
6+
* Cosine (angular) similarity function.
7+
*/
8+
COSINE("cosine"),
9+
/**
10+
* Dot product between two vectors.
11+
*/
12+
DOT_PRODUCT("dotProduct"),
13+
/**
14+
* Euclidean distance function. Also known as L2 norm.
15+
*/
16+
EUCLIDEAN("euclidean");
617

718
private final String value;
819

@@ -23,7 +34,7 @@ public String getValue() {
2334
*/
2435
public static DistanceFunction fromString(String text) {
2536
if (text == null || text.isEmpty()) {
26-
return COSINE_SIMILARITY;
37+
return null;
2738
}
2839

2940
for (DistanceFunction b : DistanceFunction.values()) {

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/IndexKind.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public String getValue() {
2323
*/
2424
public static IndexKind fromString(String text) {
2525
if (text == null || text.isEmpty()) {
26-
return FLAT;
26+
return null;
2727
}
2828

2929
for (IndexKind b : IndexKind.values()) {

0 commit comments

Comments
 (0)