Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
<PackageVersion Include="Azure.Extensions.AspNetCore.Configuration.Secrets" Version="1.5.0" />
<PackageVersion Include="Azure.Identity" Version="1.21.0" />
<PackageVersion Include="Azure.Monitor.OpenTelemetry.AspNetCore" Version="1.4.0" />
<PackageVersion Include="Microsoft.ApplicationInsights.Profiler.AspNetCore" Version="3.0.2" />
<PackageVersion Include="TUnit" Version="1.33.0" />
<PackageVersion Include="EssentialCSharp.Shared.Models" Version="$(ToolingPackagesVersion)" />
<PackageVersion Include="HtmlAgilityPack" Version="1.12.4" />
Expand All @@ -39,9 +38,13 @@
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="10.0.5" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.5" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Tools" Version="10.0.3" />
<PackageVersion Include="Microsoft.Extensions.Http.Resilience" Version="10.2.0" />
<PackageVersion Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.PgVector" Version="$(SemanticKernelVersion)-preview" />
<PackageVersion Include="Microsoft.SourceLink.GitHub" Version="10.0.201" />
<PackageVersion Include="Microsoft.SourceLink.GitHub" Version="10.0.202" />
<!-- Pin to patched versions; NuGet.Protocol 6.12.5+ fixes GHSA-g4vj-cjjj-v7hg (pulled in by Microsoft.VisualStudio.Web.CodeGeneration.Design) -->
<PackageVersion Include="NuGet.Packaging" Version="6.12.5" />
<PackageVersion Include="NuGet.Protocol" Version="6.12.5" />
<PackageVersion Include="Microsoft.VisualStudio.Azure.Containers.Tools.Targets" Version="1.23.0" />
<PackageVersion Include="Microsoft.VisualStudio.Web.CodeGeneration.Design" Version="10.0.2" />
<PackageVersion Include="ModelContextProtocol" Version="0.3.0-preview.4" />
Expand All @@ -51,6 +54,12 @@
<PackageVersion Include="System.CommandLine" Version="2.0.5" />
<PackageVersion Include="Newtonsoft.Json" Version="13.0.4" />
<PackageVersion Include="Octokit" Version="14.0.0" />
<PackageVersion Include="OpenTelemetry.Exporter.OpenTelemetryProtocol" Version="1.15.2" />
<PackageVersion Include="OpenTelemetry.Extensions.Hosting" Version="1.15.2" />
<PackageVersion Include="OpenTelemetry.Instrumentation.AspNetCore" Version="1.15.1" />
<PackageVersion Include="OpenTelemetry.Instrumentation.Http" Version="1.15.0" />
<PackageVersion Include="OpenTelemetry.Instrumentation.Runtime" Version="1.15.0" />
<PackageVersion Include="OpenTelemetry.Instrumentation.SqlClient" Version="1.15.1" />
<PackageVersion Include="DotnetSitemapGenerator" Version="2.0.0" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ public static IServiceCollection AddAzureOpenAIServices(

/// <summary>
/// Adds PostgreSQL vector store with managed identity authentication support.
/// NOTE: Token is obtained once at startup and will expire after ~1 hour.
/// For long-running applications, consider implementing token refresh logic.
/// Uses per-connection token refresh via <c>UsePasswordProvider</c>, which calls
/// <see cref="TokenCredential.GetTokenAsync"/> on every new physical connection.
/// <see cref="DefaultAzureCredential"/> caches tokens internally and auto-refreshes
/// ~5 minutes before expiry, so this does not add Azure AD overhead.
/// </summary>
/// <param name="services">The service collection to add services to</param>
/// <param name="connectionString">The PostgreSQL connection string (without password)</param>
Expand All @@ -115,36 +117,51 @@ private static IServiceCollection AddPostgresVectorStoreWithManagedIdentity(
{
credential ??= new DefaultAzureCredential();

// Parse the connection string to extract host, database, and username
var builder = new NpgsqlConnectionStringBuilder(connectionString);

// Check if this is an Azure PostgreSQL connection (contains .postgres.database.azure.com)
bool isAzurePostgres = builder.Host?.Contains(".postgres.database.azure.com", StringComparison.OrdinalIgnoreCase) ?? false;

if (isAzurePostgres && string.IsNullOrEmpty(builder.Password))
{
// Get access token for Azure PostgreSQL using managed identity
var tokenRequestContext = new TokenRequestContext(_PostgresScopes);
var accessToken = credential.GetToken(tokenRequestContext, default);

// Set the password to the access token
builder.Password = accessToken.Token;

// Ensure SSL is enabled for Azure
if (builder.SslMode == SslMode.Disable)
{
builder.SslMode = SslMode.Require;
}

connectionString = builder.ToString();
}

// Register NpgsqlDataSource with UseVector() enabled - this is critical for pgvector support
services.AddSingleton<NpgsqlDataSource>(sp =>
{
var connBuilder = new NpgsqlConnectionStringBuilder(connectionString);
bool isAzurePostgres = connBuilder.Host?.Contains(".postgres.database.azure.com",
StringComparison.OrdinalIgnoreCase) ?? false;

var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionString);
// IMPORTANT: UseVector() must be called to enable pgvector support
dataSourceBuilder.UseVector();

if (isAzurePostgres && string.IsNullOrEmpty(connBuilder.Password))
{
// Ensure SSL is enabled for Azure PostgreSQL
if (dataSourceBuilder.ConnectionStringBuilder.SslMode < SslMode.Require)
{
dataSourceBuilder.ConnectionStringBuilder.SslMode = SslMode.Require;
}

var tokenRequestContext = new TokenRequestContext(_PostgresScopes);

// UsePasswordProvider is called for every new physical connection.
// DefaultAzureCredential caches tokens internally and auto-refreshes ~5 min before
// expiry — no extra Azure AD load. This is the approach recommended by the Npgsql
// docs for cloud providers that implement their own caching (Azure MI does).
// UsePeriodicPasswordProvider is only for token sources without built-in caching.
// See: https://www.npgsql.org/doc/security.html
// See: https://github.com/npgsql/npgsql/issues/5186
//
// Note: The username is expected to be set in the connection string already
// (Aspire sets it during deployment for Azure PostgreSQL Flexible Server).
// If a standalone username-extraction fallback is ever needed, use the
// Microsoft.Azure.PostgreSQL.Auth package (UseEntraAuthentication extension)
// once it ships on NuGet.
dataSourceBuilder.UsePasswordProvider(
passwordProvider: _ => credential.GetToken(tokenRequestContext, default).Token,
passwordProviderAsync: async (_, ct) =>
(await credential.GetTokenAsync(tokenRequestContext, ct)).Token);

// Recycle pooled connections after 50 min, well before the 60-min JWT token TTL.
// Combined with UsePasswordProvider (called on every new physical connection),
// this ensures no pooled connection ever holds an expired token.
dataSourceBuilder.ConnectionStringBuilder.ConnectionLifetime = 3000;
}

return dataSourceBuilder.Build();
});

Expand Down
4 changes: 2 additions & 2 deletions EssentialCSharp.Chat.Shared/Services/AIChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ private async Task<string> EnrichPromptWithContext(string prompt, bool enableCon
return prompt;
}

var searchResults = await _SearchService.ExecuteVectorSearch(prompt);
var searchResults = await _SearchService.ExecuteVectorSearch(prompt, cancellationToken: cancellationToken);
var contextualInfo = new System.Text.StringBuilder();

contextualInfo.AppendLine("## Contextual Information");
contextualInfo.AppendLine("The following information might be relevant to your question:");
contextualInfo.AppendLine();

await foreach (var result in searchResults)
foreach (var result in searchResults)
{
contextualInfo.AppendLine(System.Globalization.CultureInfo.InvariantCulture, $"**From: {result.Record.Heading}**");
contextualInfo.AppendLine(result.Record.ChunkText);
Expand Down
41 changes: 34 additions & 7 deletions EssentialCSharp.Chat.Shared/Services/AISearchService.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,54 @@
using EssentialCSharp.Chat.Common.Models;
using System.Diagnostics;
using EssentialCSharp.Chat.Common.Models;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.VectorData;
using Npgsql;

namespace EssentialCSharp.Chat.Common.Services;

public class AISearchService(VectorStore vectorStore, EmbeddingService embeddingService)
public class AISearchService(
VectorStore vectorStore,
EmbeddingService embeddingService,
ILogger<AISearchService> logger)
{
// TODO: Implement Hybrid Search functionality, may need to switch db providers to support full text search?

public async Task<IAsyncEnumerable<VectorSearchResult<BookContentChunk>>> ExecuteVectorSearch(string query, string? collectionName = null)
public async Task<IReadOnlyList<VectorSearchResult<BookContentChunk>>> ExecuteVectorSearch(
string query, string? collectionName = null, CancellationToken cancellationToken = default)
{
collectionName ??= EmbeddingService.CollectionName;

VectorStoreCollection<string, BookContentChunk> collection = vectorStore.GetCollection<string, BookContentChunk>(collectionName);

ReadOnlyMemory<float> searchVector = await embeddingService.GenerateEmbeddingAsync(query);
ReadOnlyMemory<float> searchVector = await embeddingService.GenerateEmbeddingAsync(query, cancellationToken);

var vectorSearchOptions = new VectorSearchOptions<BookContentChunk>
{
VectorProperty = x => x.TextEmbedding,
};

var searchResults = collection.SearchAsync(searchVector, options: vectorSearchOptions, top: 3);

return searchResults;
for (int attempt = 0; attempt <= 1; attempt++)
{
try
{
var results = new List<VectorSearchResult<BookContentChunk>>();
await foreach (var result in collection.SearchAsync(searchVector, options: vectorSearchOptions, top: 3, cancellationToken: cancellationToken))
{
results.Add(result);
}
return results;
}
catch (PostgresException ex) when (ex.SqlState == "28000" && attempt == 0)
{
// The pooled connection held an expired Entra ID token. Npgsql automatically
// removes the broken connection from the pool on error — no manual pool clearing
// needed (clearing would evict all healthy connections, hurting concurrent users).
// The retry opens a fresh physical connection, which calls UsePasswordProvider
// and gets a new token from DefaultAzureCredential.
logger.LogWarning(ex, "Entra ID token expired on pooled connection (SqlState 28000); retrying once.");
}
Comment thread
BenjaminMichaelis marked this conversation as resolved.
}

throw new UnreachableException("Retry loop exited without returning or throwing.");
}
}
9 changes: 9 additions & 0 deletions EssentialCSharp.Web.Tests/WebApplicationFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ protected override void ConfigureWebHost(IWebHostBuilder builder)
services.Remove(dbConnectionDescriptor);
}

// Remove DatabaseMigrationService: it calls MigrateAsync which conflicts
// with EnsureCreated() used below for the in-memory SQLite test database.
ServiceDescriptor? migrationServiceDescriptor = services.SingleOrDefault(
d => d.ImplementationType == typeof(DatabaseMigrationService));
if (migrationServiceDescriptor != null)
{
services.Remove(migrationServiceDescriptor);
}

_Connection = new SqliteConnection(SqlConnectionString);
_Connection.Open();

Expand Down
12 changes: 2 additions & 10 deletions EssentialCSharp.Web/DatabaseMigrationService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,7 @@ public class DatabaseMigrationService(IServiceScopeFactory services) : Backgroun
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
using IServiceScope scope = Services.CreateScope();
EssentialCSharpWebContext? context = scope.ServiceProvider.GetRequiredService<EssentialCSharpWebContext>()
?? throw new InvalidOperationException($"EssentialCSharpWebContext not found for {nameof(DatabaseMigrationService)}");
if (!context.Database.GetPendingMigrations().Contains("20231021170008_CreateIdentitySchema"))
{
await context.Database.MigrateAsync(stoppingToken);
}
else
{
await context.Database.EnsureCreatedAsync(cancellationToken: stoppingToken);
}
EssentialCSharpWebContext context = scope.ServiceProvider.GetRequiredService<EssentialCSharpWebContext>();
await context.Database.MigrateAsync(stoppingToken);
}
}
23 changes: 19 additions & 4 deletions EssentialCSharp.Web/EssentialCSharp.Web.csproj
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<Project Sdk="Microsoft.NET.Sdk.Web">
<PropertyGroup>
<TargetFramework>net10.0</TargetFramework>
<!--
CA1873: Logging argument evaluation - suppress for now; affects 14+ scaffolded Identity pages.
TODO: Address by converting to LoggerMessage source generators in a follow-up.
-->
<NoWarn>$(NoWarn);CA1873</NoWarn>
<!-- CA1873: Logging argument evaluation - scaffolded Identity pages.
LOGGEN036: LoggerMessage source generator - JsonResult lacks ToString/IConvertible (.NET 10). -->
<NoWarn>$(NoWarn);CA1873;LOGGEN036</NoWarn>
</PropertyGroup>

<ItemGroup>
Expand All @@ -27,27 +29,40 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Azure.Monitor.OpenTelemetry.AspNetCore" />
<PackageReference Include="AspNet.Security.OAuth.GitHub" />
<PackageReference Include="Azure.Monitor.OpenTelemetry.AspNetCore" />
<PackageReference Include="EssentialCSharp.Shared.Models" />
<PackageReference Include="HtmlAgilityPack" />
<PackageReference Include="IntelliTect.Multitool" />
<PackageReference Include="Mailjet.Api" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" />
<PackageReference Include="Microsoft.AspNetCore.Authentication.MicrosoftAccount" />
<PackageReference Include="Microsoft.ApplicationInsights.Profiler.AspNetCore" />
<PackageReference Include="Microsoft.AspNetCore.Identity.EntityFrameworkCore" />
<PackageReference Include="Microsoft.AspNetCore.Identity.UI" />
<PackageReference Include="Microsoft.AspNetCore.Mvc.Razor.RuntimeCompilation" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Tools">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.Extensions.Http.Resilience" />
<PackageReference Include="Microsoft.VisualStudio.Web.CodeGeneration.Design" />
<!-- Pin patched NuGet packages; transitive via CodeGeneration.Design; GHSA-g4vj-cjjj-v7hg -->
<PackageReference Include="NuGet.Packaging">
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="NuGet.Protocol">
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="Newtonsoft.Json" />
<PackageReference Include="Microsoft.VisualStudio.Azure.Containers.Tools.Targets" />
<PackageReference Include="Octokit" />
<PackageReference Include="DotnetSitemapGenerator" />
<PackageReference Include="OpenTelemetry.Exporter.OpenTelemetryProtocol" />
<PackageReference Include="OpenTelemetry.Extensions.Hosting" />
<PackageReference Include="OpenTelemetry.Instrumentation.AspNetCore" />
<PackageReference Include="OpenTelemetry.Instrumentation.Http" />
<PackageReference Include="OpenTelemetry.Instrumentation.Runtime" />
<PackageReference Include="OpenTelemetry.Instrumentation.SqlClient" />
</ItemGroup>
<ItemGroup>
<Content Update="wwwroot\images\00mindmap.svg">
Expand Down
Loading
Loading