Skip to content

Add test for chat service && pull out document search into an individual service #270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 2, 2024
2 changes: 1 addition & 1 deletion app/backend/Extensions/SearchClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ internal static async Task<SupportingContentRecord[]> QueryDocumentsAsync(
Size = top,
};

if (embedding != null && overrides?.RetrievalMode != "Text")
if (embedding != null && overrides?.RetrievalMode != RetrievalMode.Text)
{
var k = useSemanticRanker ? 50 : top;
var vectorQuery = new VectorizedQuery(embedding)
Expand Down
13 changes: 7 additions & 6 deletions app/backend/Extensions/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,19 @@ internal static IServiceCollection AddAzureServices(this IServiceCollection serv
return sp.GetRequiredService<BlobServiceClient>().GetBlobContainerClient(azureStorageContainer);
});

services.AddSingleton<SearchClient>(sp =>
services.AddSingleton<IDocumentService, AzureDocumentService>(sp =>
{
var config = sp.GetRequiredService<IConfiguration>();
var (azureSearchServiceEndpoint, azureSearchIndex) =
(config["AzureSearchServiceEndpoint"], config["AzureSearchIndex"]);

var azureSearchServiceEndpoint = config["AzureSearchServiceEndpoint"];
ArgumentNullException.ThrowIfNullOrEmpty(azureSearchServiceEndpoint);

var azureSearchIndex = config["AzureSearchIndex"];
ArgumentNullException.ThrowIfNullOrEmpty(azureSearchIndex);

var searchClient = new SearchClient(
new Uri(azureSearchServiceEndpoint), azureSearchIndex, s_azureCredential);
new Uri(azureSearchServiceEndpoint), azureSearchIndex, s_azureCredential);

return searchClient;
return new AzureDocumentService(searchClient);
});

services.AddSingleton<DocumentAnalysisClient>(sp =>
Expand Down
8 changes: 4 additions & 4 deletions app/backend/Services/ReadRetrieveReadChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ namespace MinimalApi.Services;

public class ReadRetrieveReadChatService
{
private readonly SearchClient _searchClient;
private readonly IDocumentService _searchClient;
private readonly IKernel _kernel;
private readonly IConfiguration _configuration;

public ReadRetrieveReadChatService(
SearchClient searchClient,
IDocumentService searchClient,
OpenAIClient client,
IConfiguration configuration)
{
Expand Down Expand Up @@ -45,15 +45,15 @@ public async Task<ApproachResponse> ReplyAsync(
var question = history.LastOrDefault()?.User is { } userQuestion
? userQuestion
: throw new InvalidOperationException("Use question is null");
if (overrides?.RetrievalMode != "Text" && embedding is not null)
if (overrides?.RetrievalMode != RetrievalMode.Text && embedding is not null)
{
embeddings = (await embedding.GenerateEmbeddingAsync(question, cancellationToken: cancellationToken)).ToArray();
}

// step 1
// use llm to get query if retrieval mode is not vector
string? query = null;
if (overrides?.RetrievalMode != "Vector")
if (overrides?.RetrievalMode != RetrievalMode.Vector)
{
var getQueryChat = chat.CreateNewChat(@"You are a helpful AI assistant, generate search query for followup question.
Make your respond simple and precise. Return the query only, do not return any other text.
Expand Down
3 changes: 1 addition & 2 deletions app/shared/Shared/Models/RequestOverrides.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Shared.Models;

public record RequestOverrides
{
public bool SemanticRanker { get; set; } = false;

public string RetrievalMode { get; set; } = "Vector"; // available option: Text, Vector, Hybrid
public RetrievalMode RetrievalMode { get; set; } = RetrievalMode.Vector; // available option: Text, Vector, Hybrid

public bool? SemanticCaptions { get; set; }
public string? ExcludeCategory { get; set; }
Expand Down
24 changes: 24 additions & 0 deletions app/shared/Shared/Models/RetrievalMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft. All rights reserved.

namespace Shared.Models;

/// <summary>
/// Selects which inputs are used to retrieve documents from the Azure search service.
/// </summary>
public enum RetrievalMode
{
    /// <summary>
    /// Text-only mode: only the query text is used to retrieve the results.
    /// </summary>
    Text = 0,

    /// <summary>
    /// Vector-only mode: only the embeddings are used to retrieve the results.
    /// </summary>
    Vector = 1,

    /// <summary>
    /// Text + vector mode: both the query text and the embeddings are used to retrieve the results.
    /// </summary>
    Hybrid = 2,
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,30 @@

using System.Net.Http.Headers;
using System.Text;
using System.Text.Json;
using Azure.Core;

namespace MinimalApi.Services;

public class AzureComputerVisionService(IHttpClientFactory httpClientFactory, string endPoint, string apiKey)
public class AzureComputerVisionService(HttpClient client, string endPoint, TokenCredential tokenCredential)
{
// add virtual keyword to make it mockable
public virtual async Task<ImageEmbeddingResponse> VectorizeImageAsync(string imagePathOrUrl, CancellationToken ct = default)
public async Task<ImageEmbeddingResponse> VectorizeImageAsync(string imagePathOrUrl, CancellationToken ct = default)
{
var api = $"{endPoint}/computervision/retrieval:vectorizeImage?api-version=2023-02-01-preview&modelVersion=latest";
var token = await tokenCredential.GetTokenAsync(new TokenRequestContext(new[] { "https://cognitiveservices.azure.com/.default" }), ct);
// first try to read as local file
if (File.Exists(imagePathOrUrl))
{
using var request = new HttpRequestMessage(HttpMethod.Post, api);

// set authorization header
request.Headers.Add("Ocp-Apim-Subscription-Key", apiKey);
request.Headers.Add("Authorization", $"Bearer {token.Token}");

// set body
var bytes = await File.ReadAllBytesAsync(imagePathOrUrl, ct);
request.Content = new ByteArrayContent(bytes);
request.Content.Headers.ContentType = new MediaTypeHeaderValue("image/*");

// send request
using var client = httpClientFactory.CreateClient();
using var response = await client.SendAsync(request, ct);
response.EnsureSuccessStatusCode();

Expand All @@ -44,14 +44,13 @@ public virtual async Task<ImageEmbeddingResponse> VectorizeImageAsync(string ima
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));

// set authorization header
request.Headers.Add("Ocp-Apim-Subscription-Key", apiKey);
request.Headers.Add("Authorization", $"Bearer {token.Token}");

// set body
var body = new { url = imagePathOrUrl };
request.Content = new StringContent(JsonSerializer.Serialize(body), Encoding.UTF8, "application/json");

// send request
using var client = httpClientFactory.CreateClient();
using var response = await client.SendAsync(request, ct);
response.EnsureSuccessStatusCode();

Expand All @@ -67,13 +66,14 @@ public virtual async Task<ImageEmbeddingResponse> VectorizeTextAsync(string text
{
var api = $"{endPoint}/computervision/retrieval:vectorizeText?api-version=2023-02-01-preview&modelVersion=latest";

var token = await tokenCredential.GetTokenAsync(new TokenRequestContext(new[] { "https://cognitiveservices.azure.com/.default" }), ct);
using var request = new HttpRequestMessage(HttpMethod.Post, api);

// set content type to application/json
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));

// set authorization header
request.Headers.Add("Ocp-Apim-Subscription-Key", apiKey);
request.Headers.Add("Authorization", $"Bearer {token.Token}");

// set body
var body = new { text };
Expand Down
123 changes: 123 additions & 0 deletions app/shared/Shared/Services/AzureDocumentService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
using Azure;
using Azure.Search.Documents;
using Azure.Search.Documents.Models;
using Shared.Models;

/// <summary>
/// Abstraction over a document search backend that returns supporting content for a
/// text query, an embedding vector, or both.
/// </summary>
public interface IDocumentService
{
/// <summary>
/// Searches for documents matching the given query text and/or embedding.
/// </summary>
/// <param name="query">Optional query text; may be <c>null</c>.</param>
/// <param name="embedding">Optional embedding vector; may be <c>null</c>.</param>
/// <param name="overrides">Optional per-request search settings; may be <c>null</c>.</param>
/// <param name="cancellationToken">Token used to cancel the operation.</param>
/// <returns>The matching supporting content records.</returns>
/// <remarks>
/// The <c>AzureDocumentService</c> implementation in this file throws
/// <c>ArgumentException</c> when both <paramref name="query"/> and
/// <paramref name="embedding"/> are <c>null</c>.
/// </remarks>
Task<SupportingContentRecord[]> QueryDocumentsAsync(
string? query = null,
float[]? embedding = null,
RequestOverrides? overrides = null,
CancellationToken cancellationToken = default);
}

/// <summary>
/// <see cref="IDocumentService"/> implementation that queries an index through an
/// Azure <see cref="SearchClient"/>.
/// </summary>
/// <param name="searchClient">Client used to execute the search requests.</param>
public class AzureDocumentService(SearchClient searchClient) : IDocumentService
{
    /// <summary>Number of results requested when the overrides do not specify <c>Top</c>.</summary>
    private const int DefaultTop = 3;

    /// <summary>
    /// Nearest-neighbour count used when the semantic ranker is enabled, so that the
    /// reranking step has more candidates to work with than the final result size.
    /// </summary>
    private const int SemanticRankerCandidateCount = 50;

    /// <summary>
    /// Searches the index with the given query text and/or embedding and converts the hits
    /// into <see cref="SupportingContentRecord"/> instances.
    /// </summary>
    /// <param name="query">Optional search text passed to the search request.</param>
    /// <param name="embedding">
    /// Optional embedding; added as a vector query on the <c>embedding</c> field unless the
    /// retrieval mode in <paramref name="overrides"/> is <see cref="RetrievalMode.Text"/>.
    /// </param>
    /// <param name="overrides">
    /// Optional settings: result count, excluded category, semantic ranker/captions and
    /// retrieval mode.
    /// </param>
    /// <param name="cancellationToken">Token used to cancel the search request.</param>
    /// <returns>
    /// One record per hit that has a string <c>sourcepage</c> and usable content; line breaks
    /// in the content are replaced by spaces. Hits without either value are skipped.
    /// </returns>
    /// <exception cref="ArgumentException">
    /// Both <paramref name="query"/> and <paramref name="embedding"/> are <c>null</c>.
    /// </exception>
    /// <exception cref="InvalidOperationException">The search response carries no value.</exception>
    public async Task<SupportingContentRecord[]> QueryDocumentsAsync(
        string? query = null,
        float[]? embedding = null,
        RequestOverrides? overrides = null,
        CancellationToken cancellationToken = default)
    {
        if (query is null && embedding is null)
        {
            throw new ArgumentException("Either query or embedding must be provided");
        }

        var top = overrides?.Top ?? DefaultTop;
        var useSemanticRanker = overrides?.SemanticRanker ?? false;
        var useSemanticCaptions = overrides?.SemanticCaptions ?? false;

        // The excluded category comes from the request, so let SearchFilter quote and escape
        // it instead of splicing the raw value into the OData expression.
        var excludeCategory = overrides?.ExcludeCategory;
        var filter = excludeCategory is null
            ? string.Empty
            : SearchFilter.Create($"category ne {excludeCategory}");

        var searchOptions = new SearchOptions
        {
            Filter = filter,
            Size = top,
        };

        if (useSemanticRanker)
        {
            // TODO: Find out whether QueryLanguage ("en-us") and QuerySpeller ("lexicon")
            // can be assigned here as well.
            searchOptions.QueryType = SearchQueryType.Semantic;
            searchOptions.SemanticSearch = new()
            {
                SemanticConfigurationName = "default",
                QueryCaption = new(useSemanticCaptions
                    ? QueryCaptionType.Extractive
                    : QueryCaptionType.None),
            };
        }

        if (embedding is not null && overrides?.RetrievalMode != RetrievalMode.Text)
        {
            var vectorQuery = new VectorizedQuery(embedding)
            {
                // With the semantic ranker enabled, ask for a larger candidate set so the
                // reranker has more documents to choose from.
                KNearestNeighborsCount = useSemanticRanker ? SemanticRankerCandidateCount : top,
            };
            vectorQuery.Fields.Add("embedding");
            searchOptions.VectorSearch = new();
            searchOptions.VectorSearch.Queries.Add(vectorQuery);
        }

        var searchResultResponse = await searchClient.SearchAsync<SearchDocument>(
            query, searchOptions, cancellationToken);
        if (searchResultResponse.Value is null)
        {
            throw new InvalidOperationException("fail to get search result");
        }

        SearchResults<SearchDocument> searchResults = searchResultResponse.Value;

        // Each SearchDocument is expected to look roughly like:
        // {
        //   "@search.score": 11.65396,
        //   "id": "Northwind_Standard_Benefits_Details_pdf-60",
        //   "content": "x-ray, lab, or imaging service, you will likely be responsible for ...",
        //   "category": null,
        //   "sourcepage": "Northwind_Standard_Benefits_Details-24.pdf",
        //   "sourcefile": "Northwind_Standard_Benefits_Details.pdf"
        // }
        var records = new List<SupportingContentRecord>();

        // Enumerate asynchronously so that fetching further result pages never blocks.
        await foreach (var result in searchResults.GetResultsAsync().WithCancellation(cancellationToken))
        {
            if (!result.Document.TryGetValue("sourcepage", out var sourcePageValue)
                || sourcePageValue is not string sourcePage)
            {
                continue;
            }

            string? content;
            if (useSemanticCaptions)
            {
                // Captions are absent when the service returned none; such hits are skipped.
                content = result.SemanticSearch?.Captions is { } captions
                    ? string.Join(" . ", captions.Select(caption => caption.Text))
                    : null;
            }
            else
            {
                // A missing or non-string content field yields null and the hit is skipped.
                result.Document.TryGetValue("content", out var contentValue);
                content = contentValue as string;
            }

            if (content is null)
            {
                continue;
            }

            content = content.Replace('\r', ' ').Replace('\n', ' ');
            records.Add(new SupportingContentRecord(sourcePage, content));
        }

        return [.. records];
    }
}
13 changes: 12 additions & 1 deletion app/shared/Shared/Shared.csproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
Expand All @@ -7,4 +7,15 @@
<LangVersion>preview</LangVersion>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.FormRecognizer" />
<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="Azure.Extensions.AspNetCore.Configuration.Secrets" />
<PackageReference Include="Azure.Identity" />
<PackageReference Include="Azure.Search.Documents" />
<PackageReference Include="Azure.Storage.Blobs" />
<PackageReference Include="Microsoft.ApplicationInsights.AspNetCore" />
<PackageReference Include="PdfSharpCore" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ protected EnvironmentSpecificFactAttribute(string skipMessage)
protected abstract bool IsEnvironmentSupported();
}

public sealed class ApiKeyFactAttribute : EnvironmentSpecificFactAttribute
public sealed class EnvironmentVariablesFactAttribute : EnvironmentSpecificFactAttribute
{
private readonly string[] _envVariableNames;

public ApiKeyFactAttribute(params string[] envVariableNames) : base($"{string.Join(", ", envVariableNames)} is not found in env")
public EnvironmentVariablesFactAttribute(params string[] envVariableNames) : base($"{string.Join(", ", envVariableNames)} is not found in env")
{
_envVariableNames = envVariableNames;
}
Expand Down
18 changes: 8 additions & 10 deletions app/tests/MinimalApi.Tests/AzureComputerVisionServiceTest.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using Azure.Core;
using Azure.Identity;
using FluentAssertions;
using MinimalApi.Services;
using NSubstitute;
Expand All @@ -8,16 +10,14 @@ namespace MinimalApi.Tests;

public class AzureComputerVisionServiceTest
{
[ApiKeyFact("AZURE_COMPUTER_VISION_API_KEY", "AZURE_COMPUTER_VISION_ENDPOINT")]
[EnvironmentVariablesFact("AZURE_COMPUTER_VISION_ENDPOINT")]
public async Task VectorizeImageTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_COMPUTER_VISION_ENDPOINT") ?? throw new InvalidOperationException();
var apiKey = Environment.GetEnvironmentVariable("AZURE_COMPUTER_VISION_API_KEY") ?? throw new InvalidOperationException();
var httpClientFactory = Substitute.For<IHttpClientFactory>();
httpClientFactory.CreateClient().ReturnsForAnyArgs(x => new HttpClient());
var service = new AzureComputerVisionService(httpClientFactory, endpoint, apiKey);
using var httpClient = new HttpClient();
var imageUrl = @"https://learn.microsoft.com/azure/ai-services/computer-vision/media/quickstarts/presentation.png";

var service = new AzureComputerVisionService(httpClient, endpoint, new DefaultAzureCredential());
var result = await service.VectorizeImageAsync(imageUrl);

result.modelVersion.Should().NotBeNullOrEmpty();
Expand Down Expand Up @@ -45,14 +45,12 @@ public async Task VectorizeImageTestAsync()
}
}

[ApiKeyFact("AZURE_COMPUTER_VISION_API_KEY", "AZURE_COMPUTER_VISION_ENDPOINT")]
[EnvironmentVariablesFact("AZURE_COMPUTER_VISION_ENDPOINT")]
public async Task VectorizeTextTestAsync()
{
var endpoint = Environment.GetEnvironmentVariable("AZURE_COMPUTER_VISION_ENDPOINT") ?? throw new InvalidOperationException();
var apiKey = Environment.GetEnvironmentVariable("AZURE_COMPUTER_VISION_API_KEY") ?? throw new InvalidOperationException();
var httpClientFactory = Substitute.For<IHttpClientFactory>();
httpClientFactory.CreateClient().ReturnsForAnyArgs(x => new HttpClient());
var service = new AzureComputerVisionService(httpClientFactory, endpoint, apiKey);
using var httpClient = new HttpClient();
var service = new AzureComputerVisionService(httpClient, endpoint, new DefaultAzureCredential());
var text = "Hello world";
var result = await service.VectorizeTextAsync(text);

Expand Down
Loading