Skip to content

Commit 5955e31

Browse files
author
Milder Hernandez
authored
Merge pull request #178 from milderhc/postgres-vector-index
Add Postgres vector index support
2 parents 2f7d845 + 99e67c8 commit 5955e31

10 files changed

Lines changed: 152 additions & 20 deletions

File tree

.github/_typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ ans = "ans" # Short for answers
2929
arange = "arange" # Method in Python numpy package
3030
prompty = "prompty" # prompty is a format name.
3131
ist = "ist" # German language
32+
Prelease = "Prelease" # Prelease is a format name.
3233

3334
[default.extend-identifiers]
3435
ags = "ags" # Azure Graph Service

.github/workflows/java-build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ jobs:
5656
run: ./mvnw -B -Pbug-check -Pcompile-jdk${{ matrix.java-versions }} test --file pom.xml
5757

5858
# Uploads test artifacts for each JDK version
59-
- uses: actions/upload-artifact@v2
59+
- uses: actions/upload-artifact@v4
6060
if: always()
6161
with:
6262
name: test_output_sk_jdk${{ matrix.java-versions }}u

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ public class Hotel {
2121
@JsonProperty("summaryEmbedding")
2222
@VectorStoreRecordVectorAttribute(dimensions = 3)
2323
private final List<Float> descriptionEmbedding;
24+
@VectorStoreRecordVectorAttribute(dimensions = 3, indexKind = "hnsw", distanceFunction = "cosine")
25+
private final List<Float> additionalEmbedding;
2426
@VectorStoreRecordDataAttribute
2527
private double rating;
2628

@@ -41,6 +43,7 @@ public Hotel(
4143
this.code = code;
4244
this.description = description;
4345
this.descriptionEmbedding = descriptionEmbedding;
46+
this.additionalEmbedding = descriptionEmbedding;
4447
this.rating = rating;
4548
}
4649

@@ -63,6 +66,9 @@ public String getDescription() {
6366
/**
 * Gets the description embedding of this hotel.
 *
 * @return the value of the {@code descriptionEmbedding} field
 */
public List<Float> getDescriptionEmbedding() {
6467
return descriptionEmbedding;
6568
}
69+
/**
 * Gets the additional embedding of this hotel.
 *
 * <p>The constructor lines visible in this change assign this field the same
 * list that is passed in for {@code descriptionEmbedding}.
 *
 * @return the value of the {@code additionalEmbedding} field
 */
public List<Float> getAdditionalEmbedding() {
70+
return additionalEmbedding;
71+
}
6672

6773
public double getRating() {
6874
return rating;

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreCollectionCreateMapping.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ private static VectorSearchAlgorithmMetric getAlgorithmMetric(
3636
}
3737

3838
switch (vectorField.getDistanceFunction()) {
39-
case COSINE_SIMILARITY:
39+
case COSINE:
4040
return VectorSearchAlgorithmMetric.COSINE;
4141
case DOT_PRODUCT:
4242
return VectorSearchAlgorithmMetric.DOT_PRODUCT;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
package com.microsoft.semantickernel.connectors.data.postgres;
3+
4+
import com.microsoft.semantickernel.data.recorddefinition.DistanceFunction;
5+
6+
public enum PostgreSQLVectorDistanceFunction {
7+
L2("vector_l2_ops", "<->"), COSINE("vector_cosine_ops", "<=>"), INNER_PRODUCT("vector_ip_ops",
8+
"<#>");
9+
10+
private final String value;
11+
private final String operator;
12+
13+
PostgreSQLVectorDistanceFunction(String value, String operator) {
14+
this.value = value;
15+
this.operator = operator;
16+
}
17+
18+
public String getValue() {
19+
return value;
20+
}
21+
22+
public String getOperator() {
23+
return operator;
24+
}
25+
26+
public static PostgreSQLVectorDistanceFunction fromDistanceFunction(DistanceFunction function) {
27+
if (function == null) {
28+
return null;
29+
}
30+
31+
switch (function) {
32+
case EUCLIDEAN:
33+
return L2;
34+
case COSINE:
35+
return COSINE;
36+
case DOT_PRODUCT:
37+
return INNER_PRODUCT;
38+
default:
39+
throw new IllegalArgumentException("Unsupported distance function: " + function);
40+
}
41+
}
42+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
package com.microsoft.semantickernel.connectors.data.postgres;
3+
4+
import com.microsoft.semantickernel.data.recorddefinition.IndexKind;
5+
6+
public enum PostgreSQLVectorIndexKind {
7+
HNSW("hnsw"), IVFFLAT("ivfflat");
8+
9+
private final String value;
10+
11+
PostgreSQLVectorIndexKind(String value) {
12+
this.value = value;
13+
}
14+
15+
public String getValue() {
16+
return value;
17+
}
18+
19+
public static PostgreSQLVectorIndexKind fromIndexKind(IndexKind indexKind) {
20+
if (indexKind == null) {
21+
return null;
22+
}
23+
24+
switch (indexKind) {
25+
case HNSW:
26+
return HNSW;
27+
case FLAT:
28+
return IVFFLAT;
29+
default:
30+
throw new IllegalArgumentException("Unsupported index kind: " + indexKind);
31+
}
32+
}
33+
}

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.sql.Connection;
2020
import java.sql.PreparedStatement;
2121
import java.sql.SQLException;
22+
import java.sql.Statement;
2223
import java.time.OffsetDateTime;
2324
import java.util.ArrayList;
2425
import java.util.Collection;
@@ -150,6 +151,29 @@ private String getColumnNamesAndTypesForVectorFields(
150151
.collect(Collectors.joining(", "));
151152
}
152153

154+
private String createIndexForVectorField(String collectionName,
155+
VectorStoreRecordVectorField vectorField) {
156+
PostgreSQLVectorIndexKind indexKind = PostgreSQLVectorIndexKind
157+
.fromIndexKind(vectorField.getIndexKind());
158+
PostgreSQLVectorDistanceFunction distanceFunction = PostgreSQLVectorDistanceFunction
159+
.fromDistanceFunction(vectorField.getDistanceFunction());
160+
161+
// If indexKind is not specified, no index is created
162+
// and pgvector performs exact nearest neighbor search.
163+
if (indexKind == null) {
164+
return null;
165+
}
166+
if (distanceFunction == null) {
167+
throw new SKException(
168+
"Distance function is required for vector field: " + vectorField.getName());
169+
}
170+
171+
return "CREATE INDEX IF NOT EXISTS " + getCollectionTableName(collectionName) + "_index"
172+
+ " ON " + getCollectionTableName(collectionName)
173+
+ " USING " + indexKind.getValue()
174+
+ " (" + vectorField.getName() + " " + distanceFunction.getValue() + ");";
175+
}
176+
153177
/**
154178
* Creates a collection.
155179
*
@@ -158,29 +182,44 @@ private String getColumnNamesAndTypesForVectorFields(
158182
* @throws SKException if an error occurs while creating the collection
159183
*/
160184
@Override
161-
@SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers
185+
@SuppressFBWarnings(value = {
186+
"SQL_NONCONSTANT_STRING_PASSED_TO_EXECUTE",
187+
"SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING"
188+
}) // SQL query is generated dynamically with valid identifiers
162189
public void createCollection(String collectionName,
163190
VectorStoreRecordDefinition recordDefinition) {
164191

165-
String createStorageTable = "CREATE TABLE IF NOT EXISTS "
166-
+ getCollectionTableName(collectionName) + " ("
167-
+ getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, "
168-
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
169-
supportedDataTypes)
170-
+ ", "
171-
+ getColumnNamesAndTypesForVectorFields(recordDefinition.getVectorFields())
172-
+ ");";
173-
174-
String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
175-
+ " (collectionId) VALUES (?)";
192+
List<VectorStoreRecordVectorField> vectorFields = recordDefinition.getVectorFields();
176193

177194
try (Connection connection = dataSource.getConnection();
178-
PreparedStatement createTable = connection.prepareStatement(createStorageTable)) {
179-
createTable.execute();
195+
Statement createTableAndIndexes = connection.createStatement()) {
196+
197+
String createStorageTable = "CREATE TABLE IF NOT EXISTS "
198+
+ getCollectionTableName(collectionName) + " ("
199+
+ getKeyColumnName(recordDefinition.getKeyField()) + " VARCHAR(255) PRIMARY KEY, "
200+
+ getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
201+
supportedDataTypes)
202+
+ ", "
203+
+ getColumnNamesAndTypesForVectorFields(recordDefinition.getVectorFields())
204+
+ ");";
205+
206+
createTableAndIndexes.addBatch(createStorageTable);
207+
for (VectorStoreRecordVectorField vectorField : vectorFields) {
208+
String createVectorIndex = createIndexForVectorField(collectionName, vectorField);
209+
210+
if (createVectorIndex != null) {
211+
createTableAndIndexes.addBatch(createVectorIndex);
212+
}
213+
}
214+
215+
createTableAndIndexes.executeBatch();
180216
} catch (SQLException e) {
181217
throw new SKException("Failed to create collection", e);
182218
}
183219

220+
String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
221+
+ " (collectionId) VALUES (?)";
222+
184223
try (Connection connection = dataSource.getConnection();
185224
PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) {
186225
insert.setObject(1, collectionName);

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreCollectionCreateMapping.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ private static String getAlgorithmMetric(
3737
}
3838

3939
switch (vectorField.getDistanceFunction()) {
40-
case COSINE_SIMILARITY:
40+
case COSINE:
4141
return RedisVectorDistanceMetric.COSINE;
4242
case DOT_PRODUCT:
4343
return RedisVectorDistanceMetric.DOT_PRODUCT;

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/DistanceFunction.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,18 @@
22
package com.microsoft.semantickernel.data.recorddefinition;
33

44
public enum DistanceFunction {
5-
COSINE_SIMILARITY("cosineSimilarity"), DOT_PRODUCT("dotProduct"), EUCLIDEAN("euclidean");
5+
/**
6+
* Cosine (angular) similarity function.
7+
*/
8+
COSINE("cosine"),
9+
/**
10+
* Dot product between two vectors.
11+
*/
12+
DOT_PRODUCT("dotProduct"),
13+
/**
14+
* Euclidean distance function. Also known as L2 norm.
15+
*/
16+
EUCLIDEAN("euclidean");
617

718
private final String value;
819

@@ -23,7 +34,7 @@ public String getValue() {
2334
*/
2435
public static DistanceFunction fromString(String text) {
2536
if (text == null || text.isEmpty()) {
26-
return COSINE_SIMILARITY;
37+
return null;
2738
}
2839

2940
for (DistanceFunction b : DistanceFunction.values()) {

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/IndexKind.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public String getValue() {
2323
*/
2424
public static IndexKind fromString(String text) {
2525
if (text == null || text.isEmpty()) {
26-
return FLAT;
26+
return null;
2727
}
2828

2929
for (IndexKind b : IndexKind.values()) {

0 commit comments

Comments
 (0)