Skip to content

Commit 9543b54

Browse files
Add test for chat service && pull out document search into an individual service (#270)
## Purpose <!-- Describe the intention of the changes being proposed. What problem does it solve or functionality does it add? --> * ... ## Does this introduce a breaking change? <!-- Mark one with an "x". --> ``` [ ] Yes [ ] No ``` ## Pull Request Type What kind of change does this Pull Request introduce? <!-- Please check the one that applies to this PR using "x". --> ``` [ ] Bugfix [ ] Feature [ ] Code style update (formatting, local variables) [ ] Refactoring (no functional changes, no api changes) [ ] Documentation content changes [ ] Other... Please describe: ``` ## How to Test * Get the code ``` git clone [repo-address] cd [repo-name] git checkout [branch-name] npm install ``` * Test the code <!-- Add steps to run the tests suite and/or manually test --> ``` ``` ## What to Check Verify that the following are valid * ... ## Other Information <!-- Add any other helpful information that may be needed here. --> --------- Co-authored-by: David Pine <[email protected]>
1 parent 9a0d66b commit 9543b54

12 files changed

+328
-35
lines changed

app/backend/Extensions/SearchClientExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ internal static async Task<SupportingContentRecord[]> QueryDocumentsAsync(
4141
Size = top,
4242
};
4343

44-
if (embedding != null && overrides?.RetrievalMode != "Text")
44+
if (embedding != null && overrides?.RetrievalMode != RetrievalMode.Text)
4545
{
4646
var k = useSemanticRanker ? 50 : top;
4747
var vectorQuery = new VectorizedQuery(embedding)

app/backend/Extensions/ServiceCollectionExtensions.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,19 @@ internal static IServiceCollection AddAzureServices(this IServiceCollection serv
2727
return sp.GetRequiredService<BlobServiceClient>().GetBlobContainerClient(azureStorageContainer);
2828
});
2929

30-
services.AddSingleton<SearchClient>(sp =>
30+
services.AddSingleton<IDocumentService, AzureDocumentService>(sp =>
3131
{
3232
var config = sp.GetRequiredService<IConfiguration>();
33-
var (azureSearchServiceEndpoint, azureSearchIndex) =
34-
(config["AzureSearchServiceEndpoint"], config["AzureSearchIndex"]);
35-
33+
var azureSearchServiceEndpoint = config["AzureSearchServiceEndpoint"];
3634
ArgumentNullException.ThrowIfNullOrEmpty(azureSearchServiceEndpoint);
3735

36+
var azureSearchIndex = config["AzureSearchIndex"];
37+
ArgumentNullException.ThrowIfNullOrEmpty(azureSearchIndex);
38+
3839
var searchClient = new SearchClient(
39-
new Uri(azureSearchServiceEndpoint), azureSearchIndex, s_azureCredential);
40+
new Uri(azureSearchServiceEndpoint), azureSearchIndex, s_azureCredential);
4041

41-
return searchClient;
42+
return new AzureDocumentService(searchClient);
4243
});
4344

4445
services.AddSingleton<DocumentAnalysisClient>(sp =>

app/backend/Services/ReadRetrieveReadChatService.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ namespace MinimalApi.Services;
44

55
public class ReadRetrieveReadChatService
66
{
7-
private readonly SearchClient _searchClient;
7+
private readonly IDocumentService _searchClient;
88
private readonly IKernel _kernel;
99
private readonly IConfiguration _configuration;
1010

1111
public ReadRetrieveReadChatService(
12-
SearchClient searchClient,
12+
IDocumentService searchClient,
1313
OpenAIClient client,
1414
IConfiguration configuration)
1515
{
@@ -45,15 +45,15 @@ public async Task<ApproachResponse> ReplyAsync(
4545
var question = history.LastOrDefault()?.User is { } userQuestion
4646
? userQuestion
4747
: throw new InvalidOperationException("Use question is null");
48-
if (overrides?.RetrievalMode != "Text" && embedding is not null)
48+
if (overrides?.RetrievalMode != RetrievalMode.Text && embedding is not null)
4949
{
5050
embeddings = (await embedding.GenerateEmbeddingAsync(question, cancellationToken: cancellationToken)).ToArray();
5151
}
5252

5353
// step 1
5454
// use llm to get query if retrieval mode is not vector
5555
string? query = null;
56-
if (overrides?.RetrievalMode != "Vector")
56+
if (overrides?.RetrievalMode != RetrievalMode.Vector)
5757
{
5858
var getQueryChat = chat.CreateNewChat(@"You are a helpful AI assistant, generate search query for followup question.
5959
Make your respond simple and precise. Return the query only, do not return any other text.

app/shared/Shared/Models/RequestOverrides.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
namespace Shared.Models;
4-
54
public record RequestOverrides
65
{
76
public bool SemanticRanker { get; set; } = false;
87

9-
public string RetrievalMode { get; set; } = "Vector"; // available option: Text, Vector, Hybrid
8+
public RetrievalMode RetrievalMode { get; set; } = RetrievalMode.Vector; // available option: Text, Vector, Hybrid
109

1110
public bool? SemanticCaptions { get; set; }
1211
public string? ExcludeCategory { get; set; }
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
namespace Shared.Models;
4+
5+
/// <summary>
6+
/// retrieval mode for azure search service
7+
/// </summary>
8+
public enum RetrievalMode
9+
{
10+
/// <summary>
11+
/// Text-only model, where only query will be used to retrieve the results
12+
/// </summary>
13+
Text = 0,
14+
15+
/// <summary>
16+
/// Vector-only model, where only embeddings will be used to retrieve the results
17+
/// </summary>
18+
Vector,
19+
20+
/// <summary>
21+
/// Text + Vector model, where both query and embeddings will be used to retrieve the results
22+
/// </summary>
23+
Hybrid,
24+
}

app/backend/Services/AzureComputerVisionService.cs renamed to app/shared/Shared/Services/AzureComputerVisionService.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,30 @@
22

33
using System.Net.Http.Headers;
44
using System.Text;
5+
using System.Text.Json;
6+
using Azure.Core;
57

6-
namespace MinimalApi.Services;
7-
8-
public class AzureComputerVisionService(IHttpClientFactory httpClientFactory, string endPoint, string apiKey)
8+
public class AzureComputerVisionService(HttpClient client, string endPoint, TokenCredential tokenCredential)
99
{
1010
// add virtual keyword to make it mockable
11-
public virtual async Task<ImageEmbeddingResponse> VectorizeImageAsync(string imagePathOrUrl, CancellationToken ct = default)
11+
public async Task<ImageEmbeddingResponse> VectorizeImageAsync(string imagePathOrUrl, CancellationToken ct = default)
1212
{
1313
var api = $"{endPoint}/computervision/retrieval:vectorizeImage?api-version=2023-02-01-preview&modelVersion=latest";
14+
var token = await tokenCredential.GetTokenAsync(new TokenRequestContext(new[] { "https://cognitiveservices.azure.com/.default" }), ct);
1415
// first try to read as local file
1516
if (File.Exists(imagePathOrUrl))
1617
{
1718
using var request = new HttpRequestMessage(HttpMethod.Post, api);
1819

1920
// set authorization header
20-
request.Headers.Add("Ocp-Apim-Subscription-Key", apiKey);
21+
request.Headers.Add("Authorization", $"Bearer {token.Token}");
2122

2223
// set body
2324
var bytes = await File.ReadAllBytesAsync(imagePathOrUrl, ct);
2425
request.Content = new ByteArrayContent(bytes);
2526
request.Content.Headers.ContentType = new MediaTypeHeaderValue("image/*");
2627

2728
// send request
28-
using var client = httpClientFactory.CreateClient();
2929
using var response = await client.SendAsync(request, ct);
3030
response.EnsureSuccessStatusCode();
3131

@@ -44,14 +44,13 @@ public virtual async Task<ImageEmbeddingResponse> VectorizeImageAsync(string ima
4444
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
4545

4646
// set authorization header
47-
request.Headers.Add("Ocp-Apim-Subscription-Key", apiKey);
47+
request.Headers.Add("Authorization", $"Bearer {token.Token}");
4848

4949
// set body
5050
var body = new { url = imagePathOrUrl };
5151
request.Content = new StringContent(JsonSerializer.Serialize(body), Encoding.UTF8, "application/json");
5252

5353
// send request
54-
using var client = httpClientFactory.CreateClient();
5554
using var response = await client.SendAsync(request, ct);
5655
response.EnsureSuccessStatusCode();
5756

@@ -67,13 +66,14 @@ public virtual async Task<ImageEmbeddingResponse> VectorizeTextAsync(string text
6766
{
6867
var api = $"{endPoint}/computervision/retrieval:vectorizeText?api-version=2023-02-01-preview&modelVersion=latest";
6968

69+
var token = await tokenCredential.GetTokenAsync(new TokenRequestContext(new[] { "https://cognitiveservices.azure.com/.default" }), ct);
7070
using var request = new HttpRequestMessage(HttpMethod.Post, api);
7171

7272
// set content type to application/json
7373
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
7474

7575
// set authorization header
76-
request.Headers.Add("Ocp-Apim-Subscription-Key", apiKey);
76+
request.Headers.Add("Authorization", $"Bearer {token.Token}");
7777

7878
// set body
7979
var body = new { text };
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
using Azure;
2+
using Azure.Search.Documents;
3+
using Azure.Search.Documents.Models;
4+
using Shared.Models;
5+
6+
public interface IDocumentService
7+
{
8+
Task<SupportingContentRecord[]> QueryDocumentsAsync(
9+
string? query = null,
10+
float[]? embedding = null,
11+
RequestOverrides? overrides = null,
12+
CancellationToken cancellationToken = default);
13+
}
14+
15+
public class AzureDocumentService(SearchClient searchClient) : IDocumentService
16+
{
17+
public async Task<SupportingContentRecord[]> QueryDocumentsAsync(
18+
string? query = null,
19+
float[]? embedding = null,
20+
RequestOverrides? overrides = null,
21+
CancellationToken cancellationToken = default)
22+
{
23+
if (query is null && embedding is null)
24+
{
25+
throw new ArgumentException("Either query or embedding must be provided");
26+
}
27+
28+
var documentContents = string.Empty;
29+
var top = overrides?.Top ?? 3;
30+
var exclude_category = overrides?.ExcludeCategory;
31+
var filter = exclude_category == null ? string.Empty : $"category ne '{exclude_category}'";
32+
var useSemanticRanker = overrides?.SemanticRanker ?? false;
33+
var useSemanticCaptions = overrides?.SemanticCaptions ?? false;
34+
35+
SearchOptions searchOptions = useSemanticRanker
36+
? new SearchOptions
37+
{
38+
Filter = filter,
39+
QueryType = SearchQueryType.Semantic,
40+
SemanticSearch = new()
41+
{
42+
SemanticConfigurationName = "default",
43+
QueryCaption = new(useSemanticCaptions
44+
? QueryCaptionType.Extractive
45+
: QueryCaptionType.None),
46+
},
47+
// TODO: Find if these options are assignable
48+
//QueryLanguage = "en-us",
49+
//QuerySpeller = "lexicon",
50+
Size = top,
51+
}
52+
: new SearchOptions
53+
{
54+
Filter = filter,
55+
Size = top,
56+
};
57+
58+
if (embedding != null && overrides?.RetrievalMode != RetrievalMode.Text)
59+
{
60+
var k = useSemanticRanker ? 50 : top;
61+
var vectorQuery = new VectorizedQuery(embedding)
62+
{
63+
// if semantic ranker is enabled, we need to set the rank to a large number to get more
64+
// candidates for semantic reranking
65+
KNearestNeighborsCount = useSemanticRanker ? 50 : top,
66+
};
67+
vectorQuery.Fields.Add("embedding");
68+
searchOptions.VectorSearch = new();
69+
searchOptions.VectorSearch.Queries.Add(vectorQuery);
70+
}
71+
72+
var searchResultResponse = await searchClient.SearchAsync<SearchDocument>(
73+
query, searchOptions, cancellationToken);
74+
if (searchResultResponse.Value is null)
75+
{
76+
throw new InvalidOperationException("fail to get search result");
77+
}
78+
79+
SearchResults<SearchDocument> searchResult = searchResultResponse.Value;
80+
81+
// Assemble sources here.
82+
// Example output for each SearchDocument:
83+
// {
84+
// "@search.score": 11.65396,
85+
// "id": "Northwind_Standard_Benefits_Details_pdf-60",
86+
// "content": "x-ray, lab, or imaging service, you will likely be responsible for paying a copayment or coinsurance. The exact amount you will be required to pay will depend on the type of service you receive. You can use the Northwind app or website to look up the cost of a particular service before you receive it.\nIn some cases, the Northwind Standard plan may exclude certain diagnostic x-ray, lab, and imaging services. For example, the plan does not cover any services related to cosmetic treatments or procedures. Additionally, the plan does not cover any services for which no diagnosis is provided.\nIt’s important to note that the Northwind Standard plan does not cover any services related to emergency care. This includes diagnostic x-ray, lab, and imaging services that are needed to diagnose an emergency condition. If you have an emergency condition, you will need to seek care at an emergency room or urgent care facility.\nFinally, if you receive diagnostic x-ray, lab, or imaging services from an out-of-network provider, you may be required to pay the full cost of the service. To ensure that you are receiving services from an in-network provider, you can use the Northwind provider search ",
87+
// "category": null,
88+
// "sourcepage": "Northwind_Standard_Benefits_Details-24.pdf",
89+
// "sourcefile": "Northwind_Standard_Benefits_Details.pdf"
90+
// }
91+
var sb = new List<SupportingContentRecord>();
92+
foreach (var doc in searchResult.GetResults())
93+
{
94+
doc.Document.TryGetValue("sourcepage", out var sourcePageValue);
95+
string? contentValue;
96+
try
97+
{
98+
if (useSemanticCaptions)
99+
{
100+
var docs = doc.SemanticSearch.Captions.Select(c => c.Text);
101+
contentValue = string.Join(" . ", docs);
102+
}
103+
else
104+
{
105+
doc.Document.TryGetValue("content", out var value);
106+
contentValue = (string)value;
107+
}
108+
}
109+
catch (ArgumentNullException)
110+
{
111+
contentValue = null;
112+
}
113+
114+
if (sourcePageValue is string sourcePage && contentValue is string content)
115+
{
116+
content = content.Replace('\r', ' ').Replace('\n', ' ');
117+
sb.Add(new SupportingContentRecord(sourcePage, content));
118+
}
119+
}
120+
121+
return [.. sb];
122+
}
123+
}

app/shared/Shared/Shared.csproj

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
44
<TargetFramework>net8.0</TargetFramework>
@@ -7,4 +7,15 @@
77
<LangVersion>preview</LangVersion>
88
</PropertyGroup>
99

10+
<ItemGroup>
11+
<PackageReference Include="Azure.AI.FormRecognizer" />
12+
<PackageReference Include="Azure.AI.OpenAI" />
13+
<PackageReference Include="Azure.Extensions.AspNetCore.Configuration.Secrets" />
14+
<PackageReference Include="Azure.Identity" />
15+
<PackageReference Include="Azure.Search.Documents" />
16+
<PackageReference Include="Azure.Storage.Blobs" />
17+
<PackageReference Include="Microsoft.ApplicationInsights.AspNetCore" />
18+
<PackageReference Include="PdfSharpCore" />
19+
</ItemGroup>
20+
1021
</Project>

app/tests/MinimalApi.Tests/Attribute/EnvironmentSpecificFactAttribute.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ protected EnvironmentSpecificFactAttribute(string skipMessage)
2929
protected abstract bool IsEnvironmentSupported();
3030
}
3131

32-
public sealed class ApiKeyFactAttribute : EnvironmentSpecificFactAttribute
32+
public sealed class EnvironmentVariablesFactAttribute : EnvironmentSpecificFactAttribute
3333
{
3434
private readonly string[] _envVariableNames;
3535

36-
public ApiKeyFactAttribute(params string[] envVariableNames) : base($"{string.Join(", ", envVariableNames)} is not found in env")
36+
public EnvironmentVariablesFactAttribute(params string[] envVariableNames) : base($"{string.Join(", ", envVariableNames)} is not found in env")
3737
{
3838
_envVariableNames = envVariableNames;
3939
}

app/tests/MinimalApi.Tests/AzureComputerVisionServiceTest.cs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using Azure.Core;
4+
using Azure.Identity;
35
using FluentAssertions;
46
using MinimalApi.Services;
57
using NSubstitute;
@@ -8,16 +10,14 @@ namespace MinimalApi.Tests;
810

911
public class AzureComputerVisionServiceTest
1012
{
11-
[ApiKeyFact("AZURE_COMPUTER_VISION_API_KEY", "AZURE_COMPUTER_VISION_ENDPOINT")]
13+
[EnvironmentVariablesFact("AZURE_COMPUTER_VISION_ENDPOINT")]
1214
public async Task VectorizeImageTestAsync()
1315
{
1416
var endpoint = Environment.GetEnvironmentVariable("AZURE_COMPUTER_VISION_ENDPOINT") ?? throw new InvalidOperationException();
15-
var apiKey = Environment.GetEnvironmentVariable("AZURE_COMPUTER_VISION_API_KEY") ?? throw new InvalidOperationException();
16-
var httpClientFactory = Substitute.For<IHttpClientFactory>();
17-
httpClientFactory.CreateClient().ReturnsForAnyArgs(x => new HttpClient());
18-
var service = new AzureComputerVisionService(httpClientFactory, endpoint, apiKey);
17+
using var httpClient = new HttpClient();
1918
var imageUrl = @"https://learn.microsoft.com/azure/ai-services/computer-vision/media/quickstarts/presentation.png";
2019

20+
var service = new AzureComputerVisionService(httpClient, endpoint, new DefaultAzureCredential());
2121
var result = await service.VectorizeImageAsync(imageUrl);
2222

2323
result.modelVersion.Should().NotBeNullOrEmpty();
@@ -45,14 +45,12 @@ public async Task VectorizeImageTestAsync()
4545
}
4646
}
4747

48-
[ApiKeyFact("AZURE_COMPUTER_VISION_API_KEY", "AZURE_COMPUTER_VISION_ENDPOINT")]
48+
[EnvironmentVariablesFact("AZURE_COMPUTER_VISION_ENDPOINT")]
4949
public async Task VectorizeTextTestAsync()
5050
{
5151
var endpoint = Environment.GetEnvironmentVariable("AZURE_COMPUTER_VISION_ENDPOINT") ?? throw new InvalidOperationException();
52-
var apiKey = Environment.GetEnvironmentVariable("AZURE_COMPUTER_VISION_API_KEY") ?? throw new InvalidOperationException();
53-
var httpClientFactory = Substitute.For<IHttpClientFactory>();
54-
httpClientFactory.CreateClient().ReturnsForAnyArgs(x => new HttpClient());
55-
var service = new AzureComputerVisionService(httpClientFactory, endpoint, apiKey);
52+
using var httpClient = new HttpClient();
53+
var service = new AzureComputerVisionService(httpClient, endpoint, new DefaultAzureCredential());
5654
var text = "Hello world";
5755
var result = await service.VectorizeTextAsync(text);
5856

0 commit comments

Comments
 (0)