-
Notifications
You must be signed in to change notification settings - Fork 111
Expand file tree
/
Copy pathAzureOpenAITextEmbeddingGenerator.cs
More file actions
123 lines (106 loc) · 4.67 KB
/
AzureOpenAITextEmbeddingGenerator.cs
File metadata and controls
123 lines (106 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Azure.Identity;
using Helpers;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI.OpenAI;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Connectors.OpenAI;
namespace Microsoft.KernelMemory.AI.AzureOpenAI;
[Experimental("KMEXP01")]
public sealed class AzureOpenAITextEmbeddingGenerator : ITextEmbeddingGenerator, ITextEmbeddingBatchGenerator
{
    private readonly ITextTokenizer _textTokenizer;
    private readonly AzureOpenAITextEmbeddingGenerationService _client;
    private readonly ILogger<AzureOpenAITextEmbeddingGenerator> _log;
    private readonly string _deployment;

    /// <summary>
    /// Text embedding generator backed by an Azure OpenAI deployment, delegating
    /// to the Semantic Kernel Azure OpenAI connector.
    /// </summary>
    /// <param name="config">Azure OpenAI settings: endpoint, deployment, auth type, token/batch limits.</param>
    /// <param name="textTokenizer">Tokenizer used for token counting; falls back to <see cref="GPT4Tokenizer"/> when null (counts may be inaccurate for other models).</param>
    /// <param name="loggerFactory">Optional logger factory; <see cref="DefaultLogger.Factory"/> is used when null.</param>
    /// <param name="httpClient">Optional custom HTTP client passed through to the connector.</param>
    /// <exception cref="NotImplementedException">Thrown when <paramref name="config"/> specifies an unsupported auth type.</exception>
    public AzureOpenAITextEmbeddingGenerator(
        AzureOpenAIConfig config,
        ITextTokenizer? textTokenizer = null,
        ILoggerFactory? loggerFactory = null,
        HttpClient? httpClient = null)
    {
        this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<AzureOpenAITextEmbeddingGenerator>();

        // Default tokenizer: warn, because token counts from the wrong tokenizer
        // can overrun the model's context limit at request time.
        if (textTokenizer == null)
        {
            this._log.LogWarning(
                "Tokenizer not specified, will use {0}. The token count might be incorrect, causing unexpected errors",
                nameof(GPT4Tokenizer));
            textTokenizer = new GPT4Tokenizer();
        }

        this._textTokenizer = textTokenizer;
        this._deployment = config.Deployment;
        this.MaxTokens = config.MaxTokenTotal;
        this.MaxBatchSize = config.MaxEmbeddingBatchSize;

        // The three auth modes differ only in the credential argument handed to
        // the Semantic Kernel connector; everything else is identical.
        this._client = config.Auth switch
        {
            AzureOpenAIConfig.AuthTypes.AzureIdentity => new AzureOpenAITextEmbeddingGenerationService(
                deploymentName: config.Deployment,
                endpoint: config.Endpoint,
                credential: azure_credential_utils.GetAzureCredential(config.APP_ENV),
                modelId: config.Deployment,
                httpClient: httpClient,
                dimensions: config.EmbeddingDimensions,
                loggerFactory: loggerFactory),

            AzureOpenAIConfig.AuthTypes.ManualTokenCredential => new AzureOpenAITextEmbeddingGenerationService(
                deploymentName: config.Deployment,
                endpoint: config.Endpoint,
                credential: config.GetTokenCredential(),
                modelId: config.Deployment,
                httpClient: httpClient,
                dimensions: config.EmbeddingDimensions,
                loggerFactory: loggerFactory),

            AzureOpenAIConfig.AuthTypes.APIKey => new AzureOpenAITextEmbeddingGenerationService(
                deploymentName: config.Deployment,
                endpoint: config.Endpoint,
                apiKey: config.APIKey,
                modelId: config.Deployment,
                httpClient: httpClient,
                dimensions: config.EmbeddingDimensions,
                loggerFactory: loggerFactory),

            _ => throw new NotImplementedException($"Azure OpenAI auth type '{config.Auth}' not available"),
        };
    }

    /// <inheritdoc/>
    public int MaxTokens { get; }

    /// <inheritdoc/>
    public int MaxBatchSize { get; }

    /// <inheritdoc/>
    public int CountTokens(string text) => this._textTokenizer.CountTokens(text);

    /// <inheritdoc/>
    public IReadOnlyList<string> GetTokens(string text) => this._textTokenizer.GetTokens(text);

    /// <inheritdoc/>
    public Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
    {
        this._log.LogTrace("Generating embedding, deployment '{0}'", this._deployment);
        return this._client.GenerateEmbeddingAsync(text, cancellationToken);
    }

    /// <inheritdoc/>
    public async Task<Embedding[]> GenerateEmbeddingBatchAsync(IEnumerable<string> textList, CancellationToken cancellationToken = default)
    {
        // Materialize once: the input is enumerated for both the count and the request.
        List<string> batch = textList.ToList();
        this._log.LogTrace("Generating embeddings, deployment '{0}', batch size '{1}'", this._deployment, batch.Count);

        IList<ReadOnlyMemory<float>> vectors = await this._client.GenerateEmbeddingsAsync(batch, cancellationToken: cancellationToken).ConfigureAwait(false);

        var result = new Embedding[vectors.Count];
        for (int i = 0; i < vectors.Count; i++)
        {
            result[i] = new Embedding(vectors[i]);
        }

        return result;
    }
}