Skip to content

Commit c7945f3

Browse files
fix
1 parent 4dfac7f commit c7945f3

3 files changed

Lines changed: 64 additions & 22 deletions

File tree

EssentialCSharp.Chat.Shared/Extensions/ServiceCollectionExtensions.cs

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ public static IServiceCollection AddAzureOpenAIServices(
101101

102102
/// <summary>
103103
/// Adds PostgreSQL vector store with managed identity authentication support.
104-
/// Uses periodic token refresh to ensure tokens are renewed before expiry.
104+
/// Uses per-connection token refresh via <c>UsePasswordProvider</c>, which calls
105+
/// <see cref="TokenCredential.GetTokenAsync"/> on every new physical connection.
106+
/// <see cref="DefaultAzureCredential"/> caches tokens internally and auto-refreshes
107+
/// ~5 minutes before expiry, so this does not add Azure AD overhead.
105108
/// </summary>
106109
/// <param name="services">The service collection to add services to</param>
107110
/// <param name="connectionString">The PostgreSQL connection string (without password)</param>
@@ -133,18 +136,30 @@ private static IServiceCollection AddPostgresVectorStoreWithManagedIdentity(
133136
dataSourceBuilder.ConnectionStringBuilder.SslMode = SslMode.Require;
134137
}
135138

136-
// Use periodic token refresh instead of a one-shot token at startup.
137-
// Azure AD tokens expire after ~1 hour; refreshing every 50 minutes
138-
// ensures uninterrupted connectivity for long-running applications.
139-
dataSourceBuilder.UsePeriodicPasswordProvider(
140-
async (_, ct) =>
141-
{
142-
var tokenRequestContext = new TokenRequestContext(_PostgresScopes);
143-
var accessToken = await credential.GetTokenAsync(tokenRequestContext, ct);
144-
return accessToken.Token;
145-
},
146-
TimeSpan.FromMinutes(50),
147-
TimeSpan.FromSeconds(10));
139+
var tokenRequestContext = new TokenRequestContext(_PostgresScopes);
140+
141+
// UsePasswordProvider is called for every new physical connection.
142+
// DefaultAzureCredential caches tokens internally and auto-refreshes ~5 min before
143+
// expiry — no extra Azure AD load. This is the approach recommended by the Npgsql
144+
// docs for cloud providers that implement their own caching (Azure MI does).
145+
// UsePeriodicPasswordProvider is only for token sources without built-in caching.
146+
// See: https://www.npgsql.org/doc/security.html
147+
// See: https://github.com/npgsql/npgsql/issues/5186
148+
//
149+
// Note: The username is expected to be set in the connection string already
150+
// (Aspire sets it during deployment for Azure PostgreSQL Flexible Server).
151+
// If a standalone username-extraction fallback is ever needed, use the
152+
// Microsoft.Azure.PostgreSQL.Auth package (UseEntraAuthentication extension)
153+
// once it ships on NuGet.
154+
dataSourceBuilder.UsePasswordProvider(
155+
passwordProvider: _ => credential.GetToken(tokenRequestContext, default).Token,
156+
passwordProviderAsync: async (_, ct) =>
157+
(await credential.GetTokenAsync(tokenRequestContext, ct)).Token);
158+
159+
// Recycle pooled connections after 50 min, well before the 60-min JWT token TTL.
160+
// Combined with UsePasswordProvider (called on every new physical connection),
161+
// this ensures no pooled connection ever holds an expired token.
162+
dataSourceBuilder.ConnectionStringBuilder.ConnectionLifetime = 3000;
148163
}
149164

150165
return dataSourceBuilder.Build();

EssentialCSharp.Chat.Shared/Services/AIChatService.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,14 @@ private async Task<string> EnrichPromptWithContext(string prompt, bool enableCon
118118
return prompt;
119119
}
120120

121-
var searchResults = await _SearchService.ExecuteVectorSearch(prompt);
121+
var searchResults = await _SearchService.ExecuteVectorSearch(prompt, cancellationToken: cancellationToken);
122122
var contextualInfo = new System.Text.StringBuilder();
123123

124124
contextualInfo.AppendLine("## Contextual Information");
125125
contextualInfo.AppendLine("The following information might be relevant to your question:");
126126
contextualInfo.AppendLine();
127127

128-
await foreach (var result in searchResults)
128+
foreach (var result in searchResults)
129129
{
130130
contextualInfo.AppendLine(System.Globalization.CultureInfo.InvariantCulture, $"**From: {result.Record.Heading}**");
131131
contextualInfo.AppendLine(result.Record.ChunkText);
Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,54 @@
1-
using EssentialCSharp.Chat.Common.Models;
1+
using System.Diagnostics;
2+
using EssentialCSharp.Chat.Common.Models;
3+
using Microsoft.Extensions.Logging;
24
using Microsoft.Extensions.VectorData;
5+
using Npgsql;
36

47
namespace EssentialCSharp.Chat.Common.Services;
58

6-
public class AISearchService(VectorStore vectorStore, EmbeddingService embeddingService)
9+
public class AISearchService(
10+
VectorStore vectorStore,
11+
EmbeddingService embeddingService,
12+
ILogger<AISearchService> logger)
713
{
814
// TODO: Implement Hybrid Search functionality, may need to switch db providers to support full text search?
915

10-
public async Task<IAsyncEnumerable<VectorSearchResult<BookContentChunk>>> ExecuteVectorSearch(string query, string? collectionName = null)
16+
public async Task<IReadOnlyList<VectorSearchResult<BookContentChunk>>> ExecuteVectorSearch(
17+
string query, string? collectionName = null, CancellationToken cancellationToken = default)
1118
{
1219
collectionName ??= EmbeddingService.CollectionName;
1320

1421
VectorStoreCollection<string, BookContentChunk> collection = vectorStore.GetCollection<string, BookContentChunk>(collectionName);
1522

16-
ReadOnlyMemory<float> searchVector = await embeddingService.GenerateEmbeddingAsync(query);
23+
ReadOnlyMemory<float> searchVector = await embeddingService.GenerateEmbeddingAsync(query, cancellationToken);
1724

1825
var vectorSearchOptions = new VectorSearchOptions<BookContentChunk>
1926
{
2027
VectorProperty = x => x.TextEmbedding,
2128
};
2229

23-
var searchResults = collection.SearchAsync(searchVector, options: vectorSearchOptions, top: 3);
24-
25-
return searchResults;
30+
for (int attempt = 0; attempt <= 1; attempt++)
31+
{
32+
try
33+
{
34+
var results = new List<VectorSearchResult<BookContentChunk>>();
35+
await foreach (var result in collection.SearchAsync(searchVector, options: vectorSearchOptions, top: 3, cancellationToken: cancellationToken))
36+
{
37+
results.Add(result);
38+
}
39+
return results;
40+
}
41+
catch (PostgresException ex) when (ex.SqlState == "28000" && attempt == 0)
42+
{
43+
// The pooled connection held an expired Entra ID token. Npgsql automatically
44+
// removes the broken connection from the pool on error — no manual pool clearing
45+
// needed (clearing would evict all healthy connections, hurting concurrent users).
46+
// The retry opens a fresh physical connection, which calls UsePasswordProvider
47+
// and gets a new token from DefaultAzureCredential.
48+
logger.LogWarning(ex, "Entra ID token expired on pooled connection (SqlState 28000); retrying once.");
49+
}
50+
}
51+
52+
throw new UnreachableException("Retry loop exited without returning or throwing.");
2653
}
2754
}

0 commit comments

Comments
 (0)