Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
<PackageVersion Include="Azure.Extensions.AspNetCore.Configuration.Secrets" Version="1.5.0" />
<PackageVersion Include="Azure.Identity" Version="1.21.0" />
<PackageVersion Include="Azure.Monitor.OpenTelemetry.AspNetCore" Version="1.4.0" />
<PackageVersion Include="Microsoft.ApplicationInsights.Profiler.AspNetCore" Version="3.0.2" />
<PackageVersion Include="TUnit" Version="1.33.0" />
<PackageVersion Include="EssentialCSharp.Shared.Models" Version="$(ToolingPackagesVersion)" />
<PackageVersion Include="HtmlAgilityPack" Version="1.12.4" />
Expand All @@ -39,9 +38,13 @@
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="10.0.5" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.5" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Tools" Version="10.0.3" />
<PackageVersion Include="Microsoft.Extensions.Http.Resilience" Version="10.2.0" />
<PackageVersion Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.PgVector" Version="$(SemanticKernelVersion)-preview" />
<PackageVersion Include="Microsoft.SourceLink.GitHub" Version="10.0.201" />
<PackageVersion Include="Microsoft.SourceLink.GitHub" Version="10.0.202" />
<!-- Pin to patched versions; NuGet.Protocol 6.12.5+ fixes GHSA-g4vj-cjjj-v7hg (pulled in by Microsoft.VisualStudio.Web.CodeGeneration.Design) -->
<PackageVersion Include="NuGet.Packaging" Version="6.12.5" />
<PackageVersion Include="NuGet.Protocol" Version="6.12.5" />
<PackageVersion Include="Microsoft.VisualStudio.Azure.Containers.Tools.Targets" Version="1.23.0" />
<PackageVersion Include="Microsoft.VisualStudio.Web.CodeGeneration.Design" Version="10.0.2" />
<PackageVersion Include="ModelContextProtocol" Version="0.3.0-preview.4" />
Expand All @@ -51,6 +54,12 @@
<PackageVersion Include="System.CommandLine" Version="2.0.5" />
<PackageVersion Include="Newtonsoft.Json" Version="13.0.4" />
<PackageVersion Include="Octokit" Version="14.0.0" />
<PackageVersion Include="OpenTelemetry.Exporter.OpenTelemetryProtocol" Version="1.15.2" />
<PackageVersion Include="OpenTelemetry.Extensions.Hosting" Version="1.15.2" />
<PackageVersion Include="OpenTelemetry.Instrumentation.AspNetCore" Version="1.15.1" />
<PackageVersion Include="OpenTelemetry.Instrumentation.Http" Version="1.15.0" />
<PackageVersion Include="OpenTelemetry.Instrumentation.Runtime" Version="1.15.0" />
<PackageVersion Include="OpenTelemetry.Instrumentation.SqlClient" Version="1.15.1" />
<PackageVersion Include="DotnetSitemapGenerator" Version="2.0.0" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ public static IServiceCollection AddAzureOpenAIServices(

/// <summary>
/// Adds PostgreSQL vector store with managed identity authentication support.
/// NOTE: Token is obtained once at startup and will expire after ~1 hour.
/// For long-running applications, consider implementing token refresh logic.
/// Uses per-connection token refresh via <c>UsePasswordProvider</c>, which calls
/// <see cref="TokenCredential.GetTokenAsync"/> on every new physical connection.
/// <see cref="DefaultAzureCredential"/> caches tokens internally and auto-refreshes
/// ~5 minutes before expiry, so this does not add Azure AD overhead.
/// </summary>
/// <param name="services">The service collection to add services to</param>
/// <param name="connectionString">The PostgreSQL connection string (without password)</param>
Expand All @@ -115,36 +117,51 @@ private static IServiceCollection AddPostgresVectorStoreWithManagedIdentity(
{
credential ??= new DefaultAzureCredential();

// Parse the connection string to extract host, database, and username
var builder = new NpgsqlConnectionStringBuilder(connectionString);

// Check if this is an Azure PostgreSQL connection (contains .postgres.database.azure.com)
bool isAzurePostgres = builder.Host?.Contains(".postgres.database.azure.com", StringComparison.OrdinalIgnoreCase) ?? false;

if (isAzurePostgres && string.IsNullOrEmpty(builder.Password))
{
// Get access token for Azure PostgreSQL using managed identity
var tokenRequestContext = new TokenRequestContext(_PostgresScopes);
var accessToken = credential.GetToken(tokenRequestContext, default);

// Set the password to the access token
builder.Password = accessToken.Token;

// Ensure SSL is enabled for Azure
if (builder.SslMode == SslMode.Disable)
{
builder.SslMode = SslMode.Require;
}

connectionString = builder.ToString();
}

// Register NpgsqlDataSource with UseVector() enabled - this is critical for pgvector support
services.AddSingleton<NpgsqlDataSource>(sp =>
{
var connBuilder = new NpgsqlConnectionStringBuilder(connectionString);
bool isAzurePostgres = connBuilder.Host?.Contains(".postgres.database.azure.com",
StringComparison.OrdinalIgnoreCase) ?? false;

var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionString);
// IMPORTANT: UseVector() must be called to enable pgvector support
dataSourceBuilder.UseVector();

if (isAzurePostgres && string.IsNullOrEmpty(connBuilder.Password))
{
// Ensure SSL is enabled for Azure PostgreSQL
if (dataSourceBuilder.ConnectionStringBuilder.SslMode < SslMode.Require)
{
dataSourceBuilder.ConnectionStringBuilder.SslMode = SslMode.Require;
}

var tokenRequestContext = new TokenRequestContext(_PostgresScopes);

// UsePasswordProvider is called for every new physical connection.
// DefaultAzureCredential caches tokens internally and auto-refreshes ~5 min before
// expiry — no extra Azure AD load. This is the approach recommended by the Npgsql
// docs for cloud providers that implement their own caching (Azure MI does).
// UsePeriodicPasswordProvider is only for token sources without built-in caching.
// See: https://www.npgsql.org/doc/security.html
// See: https://github.com/npgsql/npgsql/issues/5186
//
// Note: The username is expected to be set in the connection string already
// (Aspire sets it during deployment for Azure PostgreSQL Flexible Server).
// If a standalone username-extraction fallback is ever needed, use the
// Microsoft.Azure.PostgreSQL.Auth package (UseEntraAuthentication extension)
// once it ships on NuGet.
dataSourceBuilder.UsePasswordProvider(
passwordProvider: _ => credential.GetToken(tokenRequestContext, default).Token,
passwordProviderAsync: async (_, ct) =>
(await credential.GetTokenAsync(tokenRequestContext, ct)).Token);

// Recycle pooled connections after 50 min, well before the 60-min JWT token TTL.
// Combined with UsePasswordProvider (called on every new physical connection),
// this ensures no pooled connection ever holds an expired token.
dataSourceBuilder.ConnectionStringBuilder.ConnectionLifetime = 3000;
}

return dataSourceBuilder.Build();
});

Expand Down
4 changes: 2 additions & 2 deletions EssentialCSharp.Chat.Shared/Services/AIChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ private async Task<string> EnrichPromptWithContext(string prompt, bool enableCon
return prompt;
}

var searchResults = await _SearchService.ExecuteVectorSearch(prompt);
var searchResults = await _SearchService.ExecuteVectorSearch(prompt, cancellationToken: cancellationToken);
var contextualInfo = new System.Text.StringBuilder();

contextualInfo.AppendLine("## Contextual Information");
contextualInfo.AppendLine("The following information might be relevant to your question:");
contextualInfo.AppendLine();

await foreach (var result in searchResults)
foreach (var result in searchResults)
{
contextualInfo.AppendLine(System.Globalization.CultureInfo.InvariantCulture, $"**From: {result.Record.Heading}**");
contextualInfo.AppendLine(result.Record.ChunkText);
Expand Down
41 changes: 34 additions & 7 deletions EssentialCSharp.Chat.Shared/Services/AISearchService.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,54 @@
using EssentialCSharp.Chat.Common.Models;
using System.Diagnostics;
using EssentialCSharp.Chat.Common.Models;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.VectorData;
using Npgsql;

namespace EssentialCSharp.Chat.Common.Services;

/// <summary>
/// Runs vector similarity searches against a <see cref="VectorStore"/> collection of
/// <see cref="BookContentChunk"/> records, using <see cref="EmbeddingService"/> to embed the query text.
/// </summary>
/// <param name="vectorStore">Store from which the target collection is resolved.</param>
/// <param name="embeddingService">Generates the embedding vector for the query text.</param>
/// <param name="logger">Receives a warning when a search is retried after an authorization failure.</param>
public class AISearchService(
    VectorStore vectorStore,
    EmbeddingService embeddingService,
    ILogger<AISearchService> logger)
{
    // TODO: Implement Hybrid Search functionality, may need to switch db providers to support full text search?

    // PostgreSQL SQLSTATE that triggers the single retry in ExecuteVectorSearch.
    private const string InvalidAuthorizationSqlState = "28000";

    /// <summary>
    /// Embeds <paramref name="query"/> and returns the top 3 nearest records from the collection.
    /// </summary>
    /// <param name="query">Free-text query to embed and search with.</param>
    /// <param name="collectionName">
    /// Collection to search; defaults to <see cref="EmbeddingService.CollectionName"/> when <c>null</c>.
    /// </param>
    /// <param name="cancellationToken">Flows to embedding generation and to the vector search.</param>
    /// <returns>The fully materialized search results (at most 3 are requested).</returns>
    /// <exception cref="PostgresException">
    /// The search failed with a SQLSTATE other than 28000, or failed with 28000 on the retry as well.
    /// </exception>
    public async Task<IReadOnlyList<VectorSearchResult<BookContentChunk>>> ExecuteVectorSearch(
        string query, string? collectionName = null, CancellationToken cancellationToken = default)
    {
        collectionName ??= EmbeddingService.CollectionName;

        VectorStoreCollection<string, BookContentChunk> collection = vectorStore.GetCollection<string, BookContentChunk>(collectionName);

        // The embedding is generated once and reused for the retry; only the database call is repeated.
        ReadOnlyMemory<float> searchVector = await embeddingService.GenerateEmbeddingAsync(query, cancellationToken);

        var vectorSearchOptions = new VectorSearchOptions<BookContentChunk>
        {
            VectorProperty = x => x.TextEmbedding,
        };

        try
        {
            return await CollectResultsAsync(collection, searchVector, vectorSearchOptions, cancellationToken);
        }
        catch (PostgresException ex) when (ex.SqlState == InvalidAuthorizationSqlState)
        {
            // The pooled connection held an expired Entra ID token. Npgsql automatically
            // removes the broken connection from the pool on error — no manual pool clearing
            // needed (clearing would evict all healthy connections, hurting concurrent users).
            // The retry opens a fresh physical connection, which calls UsePasswordProvider
            // and gets a new token from DefaultAzureCredential.
            logger.LogWarning(ex, "Entra ID token expired on pooled connection (SqlState 28000); retrying once.");
        }

        // Second and final attempt: any exception thrown here, including another 28000, propagates.
        return await CollectResultsAsync(collection, searchVector, vectorSearchOptions, cancellationToken);
    }

    /// <summary>
    /// Runs one search and buffers the async stream into a list, so that any failure surfaces
    /// inside this call rather than later while a caller enumerates.
    /// </summary>
    private static async Task<IReadOnlyList<VectorSearchResult<BookContentChunk>>> CollectResultsAsync(
        VectorStoreCollection<string, BookContentChunk> collection,
        ReadOnlyMemory<float> searchVector,
        VectorSearchOptions<BookContentChunk> vectorSearchOptions,
        CancellationToken cancellationToken)
    {
        var results = new List<VectorSearchResult<BookContentChunk>>();
        await foreach (var result in collection.SearchAsync(searchVector, options: vectorSearchOptions, top: 3, cancellationToken: cancellationToken))
        {
            results.Add(result);
        }

        return results;
    }
}
109 changes: 109 additions & 0 deletions EssentialCSharp.Chat.Tests/AISearchServiceTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using EssentialCSharp.Chat.Common.Models;
using EssentialCSharp.Chat.Common.Services;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.VectorData;
using Moq;
using Moq.Language.Flow;
using Npgsql;

namespace EssentialCSharp.Chat.Tests;

/// <summary>
/// Unit tests for <see cref="AISearchService.ExecuteVectorSearch"/>: the happy path and the
/// single retry on PostgreSQL SQLSTATE 28000. The vector store, collection, embedding generator
/// and logger are all mocked, so no database is involved.
/// </summary>
public class AISearchServiceTests
{
    private static readonly BookContentChunk _TestChunk = new() { Id = "test-1", ChunkText = "test" };

    /// <summary>
    /// Builds an <see cref="AISearchService"/> over mocks and returns the collection mock so each
    /// test can script what <c>SearchAsync</c> does.
    /// </summary>
    private static (AISearchService svc, Mock<VectorStoreCollection<string, BookContentChunk>> collectionMock)
        CreateService()
    {
        var collectionMock = new Mock<VectorStoreCollection<string, BookContentChunk>>();

        var vectorStoreMock = new Mock<VectorStore>();
        vectorStoreMock
            .Setup(vs => vs.GetCollection<string, BookContentChunk>(It.IsAny<string>(), It.IsAny<VectorStoreCollectionDefinition?>()))
            .Returns(collectionMock.Object);

        // IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync is the batch interface method
        // called internally by the single-value extension used in EmbeddingService.GenerateEmbeddingAsync.
        var embGenMock = new Mock<IEmbeddingGenerator<string, Embedding<float>>>();
        embGenMock
            .Setup(g => g.GenerateAsync(
                It.IsAny<IEnumerable<string>>(),
                It.IsAny<EmbeddingGenerationOptions?>(),
                It.IsAny<CancellationToken>()))
            .ReturnsAsync(new GeneratedEmbeddings<Embedding<float>>([new Embedding<float>(new float[1536])]));

        var embeddingService = new EmbeddingService(vectorStoreMock.Object, embGenMock.Object);
        var loggerMock = new Mock<ILogger<AISearchService>>();

        return (new AISearchService(vectorStoreMock.Object, embeddingService, loggerMock.Object), collectionMock);
    }

    /// <summary>Async stream that yields exactly one result wrapping <see cref="_TestChunk"/>.</summary>
    private static async IAsyncEnumerable<VectorSearchResult<BookContentChunk>> OneResultStream()
    {
        yield return new VectorSearchResult<BookContentChunk>(_TestChunk, 0.9f);
        // Keeps the iterator genuinely async without doing any real work.
        await Task.CompletedTask;
    }

    /// <summary>Starts a Moq setup matching any <c>SearchAsync</c> call with a float-vector input.</summary>
    private static ISetup<VectorStoreCollection<string, BookContentChunk>, IAsyncEnumerable<VectorSearchResult<BookContentChunk>>>
        SetupSearch(Mock<VectorStoreCollection<string, BookContentChunk>> mock) =>
        mock.Setup(c => c.SearchAsync(
            It.IsAny<ReadOnlyMemory<float>>(),
            It.IsAny<int>(),
            It.IsAny<VectorSearchOptions<BookContentChunk>?>(),
            It.IsAny<CancellationToken>()));

    [Test]
    public async Task ExecuteVectorSearch_HappyPath_ReturnsResultsWithoutRetry()
    {
        var (svc, collectionMock) = CreateService();
        int callCount = 0;

        SetupSearch(collectionMock).Returns(() => { callCount++; return OneResultStream(); });

        var results = await svc.ExecuteVectorSearch("test query");

        await Assert.That(results.Count).IsEqualTo(1);
        await Assert.That(callCount).IsEqualTo(1);
    }

    [Test]
    public async Task ExecuteVectorSearch_RetriesOnce_WhenFirstAttemptThrows28000()
    {
        var (svc, collectionMock) = CreateService();
        int callCount = 0;

        SetupSearch(collectionMock).Returns(() =>
        {
            callCount++;
            if (callCount == 1)
            {
                throw new PostgresException("auth token expired", "FATAL", "FATAL", "28000");
            }

            return OneResultStream();
        });

        var results = await svc.ExecuteVectorSearch("test query");

        await Assert.That(results.Count).IsEqualTo(1);
        await Assert.That(callCount).IsEqualTo(2);
    }

    [Test]
    public async Task ExecuteVectorSearch_DoesNotRetry_WhenSqlStateIsNot28000()
    {
        var (svc, collectionMock) = CreateService();
        int callCount = 0;

        SetupSearch(collectionMock).Returns(() =>
        {
            callCount++;
            throw new PostgresException("table not found", "ERROR", "ERROR", "42P01");
        });

        await Assert.ThrowsAsync<PostgresException>(() => svc.ExecuteVectorSearch("test query"));

        // The exception alone does not prove "no retry"; the call count does.
        await Assert.That(callCount).IsEqualTo(1);
    }

    [Test]
    public async Task ExecuteVectorSearch_PropagatesException_WhenBothAttemptsFail28000()
    {
        var (svc, collectionMock) = CreateService();
        int callCount = 0;

        SetupSearch(collectionMock).Returns(() =>
        {
            callCount++;
            throw new PostgresException("auth failed", "FATAL", "FATAL", "28000");
        });

        await Assert.ThrowsAsync<PostgresException>(() => svc.ExecuteVectorSearch("test query"));

        // Exactly one retry: the initial attempt plus a single second attempt, then give up.
        await Assert.That(callCount).IsEqualTo(2);
    }
}
3 changes: 2 additions & 1 deletion EssentialCSharp.Web.Tests/FunctionalTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ public class FunctionalTests(WebApplicationFactory factory)
[Arguments("/hello-world")]
[Arguments("/hello-world#hello-world")]
[Arguments("/guidelines")]
[Arguments("/healthz")]
[Arguments("/health")]
[Arguments("/alive")]
public async Task WhenTheApplicationStarts_ItCanLoadLoadPages(string relativeUrl)
{
HttpClient client = factory.CreateClient();
Expand Down
32 changes: 2 additions & 30 deletions EssentialCSharp.Web.Tests/ListingSourceCodeServiceTests.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
using EssentialCSharp.Web.Models;
using EssentialCSharp.Web.Services;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.FileProviders;
using Moq;
using Moq.AutoMock;

namespace EssentialCSharp.Web.Tests;

Expand Down Expand Up @@ -130,30 +126,6 @@ public async Task GetListingsByChapterAsync_WithInvalidChapter_ReturnsEmptyList(
await Assert.That(results).IsEmpty();
}

/// <summary>
/// Creates the <see cref="ListingSourceCodeService"/> under test.
/// </summary>
/// <remarks>
/// Delegates to <see cref="TestListingSourceCodeServiceHelper.CreateService"/>, which points the
/// mocked web host environment at the TestData directory, so the fixture wiring is defined in one
/// place instead of being duplicated in this class.
/// </remarks>
private static ListingSourceCodeService CreateService() =>
    TestListingSourceCodeServiceHelper.CreateService();
}
29 changes: 29 additions & 0 deletions EssentialCSharp.Web.Tests/TestListingSourceCodeServiceHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using EssentialCSharp.Web.Services;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.FileProviders;
using Moq.AutoMock;

namespace EssentialCSharp.Web.Tests;

/// <summary>
/// Shared fixture helpers for tests that need a <see cref="ListingSourceCodeService"/> whose
/// content root is the TestData directory under the test output folder.
/// </summary>
internal static class TestListingSourceCodeServiceHelper
{
    /// <summary>
    /// Returns the path of the TestData directory located under <see cref="AppContext.BaseDirectory"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The TestData directory does not exist.</exception>
    internal static string ResolveTestDataPath()
    {
        string candidate = Path.Join(AppContext.BaseDirectory, "TestData");

        return Directory.Exists(candidate)
            ? candidate
            : throw new InvalidOperationException($"TestData directory not found at: {candidate}");
    }

    /// <summary>
    /// Creates a <see cref="ListingSourceCodeService"/> through <see cref="AutoMocker"/>, with the
    /// mocked <see cref="IWebHostEnvironment"/> reporting the TestData directory as both its content
    /// root path and its content root file provider.
    /// </summary>
    /// <exception cref="InvalidOperationException">The TestData directory does not exist.</exception>
    internal static ListingSourceCodeService CreateService()
    {
        string contentRoot = ResolveTestDataPath();
        PhysicalFileProvider contentRootFiles = new(contentRoot);

        AutoMocker autoMocker = new();
        autoMocker
            .Setup<IWebHostEnvironment, string>(env => env.ContentRootPath)
            .Returns(contentRoot);
        autoMocker
            .Setup<IWebHostEnvironment, IFileProvider>(env => env.ContentRootFileProvider)
            .Returns(contentRootFiles);

        return autoMocker.CreateInstance<ListingSourceCodeService>();
    }
}
Loading
Loading