Skip to content

Commit f5b168f

Browse files
add image embedding to azure index (#271)
## Purpose
<!-- Describe the intention of the changes being proposed. What problem does it solve or functionality does it add? -->
* ...

## Does this introduce a breaking change?
<!-- Mark one with an "x". -->
```
[ ] Yes
[ ] No
```

## Pull Request Type
What kind of change does this Pull Request introduce?
<!-- Please check the one that applies to this PR using "x". -->
```
[ ] Bugfix
[ ] Feature
[ ] Code style update (formatting, local variables)
[ ] Refactoring (no functional changes, no api changes)
[ ] Documentation content changes
[ ] Other... Please describe:
```

## How to Test
* Get the code
```
git clone [repo-address]
cd [repo-name]
git checkout [branch-name]
npm install
```
* Test the code
<!-- Add steps to run the tests suite and/or manually test -->
```
```

## What to Check
Verify that the following are valid
* ...

## Other Information
<!-- Add any other helpful information that may be needed here. -->

---------

Co-authored-by: David Pine <[email protected]>
1 parent 9543b54 commit f5b168f

37 files changed

+739
-100
lines changed

app/backend/Extensions/ServiceCollectionExtensions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ internal static IServiceCollection AddAzureServices(this IServiceCollection serv
2727
return sp.GetRequiredService<BlobServiceClient>().GetBlobContainerClient(azureStorageContainer);
2828
});
2929

30-
services.AddSingleton<IDocumentService, AzureDocumentService>(sp =>
30+
services.AddSingleton<ISearchService, AzureSearchService>(sp =>
3131
{
3232
var config = sp.GetRequiredService<IConfiguration>();
3333
var azureSearchServiceEndpoint = config["AzureSearchServiceEndpoint"];
@@ -39,7 +39,7 @@ internal static IServiceCollection AddAzureServices(this IServiceCollection serv
3939
var searchClient = new SearchClient(
4040
new Uri(azureSearchServiceEndpoint), azureSearchIndex, s_azureCredential);
4141

42-
return new AzureDocumentService(searchClient);
42+
return new AzureSearchService(searchClient);
4343
});
4444

4545
services.AddSingleton<DocumentAnalysisClient>(sp =>

app/backend/Services/ReadRetrieveReadChatService.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ namespace MinimalApi.Services;
44

55
public class ReadRetrieveReadChatService
66
{
7-
private readonly IDocumentService _searchClient;
7+
private readonly ISearchService _searchClient;
88
private readonly IKernel _kernel;
99
private readonly IConfiguration _configuration;
1010

1111
public ReadRetrieveReadChatService(
12-
IDocumentService searchClient,
12+
ISearchService searchClient,
1313
OpenAIClient client,
1414
IConfiguration configuration)
1515
{
@@ -158,6 +158,7 @@ Return the follow-up question as a json string list.
158158
}
159159
return new ApproachResponse(
160160
DataPoints: documentContentList,
161+
Images: null,
161162
Answer: ans,
162163
Thoughts: thoughts,
163164
CitationBaseUrl: _configuration.ToCitationBaseUrl());

app/frontend/Services/ApiClient.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ private async Task<AnswerResult<TRequest>> PostRequestAsync<TRequest>(
125125
$"HTTP {(int)response.StatusCode} : {response.ReasonPhrase ?? "☹️ Unknown error..."}",
126126
null,
127127
[],
128+
null,
128129
"Unable to retrieve valid response from the server.");
129130

130131
return result with

app/functions/EmbedFunctions/Program.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,17 @@ uri is not null
6565
var documentClient = provider.GetRequiredService<DocumentAnalysisClient>();
6666
var logger = provider.GetRequiredService<ILogger<AzureSearchEmbedService>>();
6767

68-
return new AzureSearchEmbedService(openAIClient, embeddingModelName, searchClient, searchIndexName, searchIndexClient, documentClient, blobContainerClient, logger);
68+
return new AzureSearchEmbedService(
69+
openAIClient: openAIClient,
70+
embeddingModelName: embeddingModelName,
71+
searchClient: searchClient,
72+
searchIndexName: searchIndexName,
73+
searchIndexClient: searchIndexClient,
74+
documentAnalysisClient: documentClient,
75+
corpusContainerClient: blobContainerClient,
76+
computerVisionService: null,
77+
includeImageEmbeddingsField: false,
78+
logger: logger);
6979
});
7080
})
7181
.ConfigureFunctionsWorkerDefaults()

app/functions/EmbedFunctions/Services/EmbeddingAggregateService.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ internal async Task EmbedBlobAsync(Stream blobStream, string blobName)
1414
var embeddingType = GetEmbeddingType();
1515
var embedService = embedServiceFactory.GetEmbedService(embeddingType);
1616

17-
var result = await embedService.EmbedBlobAsync(blobStream, blobName);
17+
var result = await embedService.EmbedPDFBlobAsync(blobStream, blobName);
1818

1919
var status = result switch
2020
{

app/functions/EmbedFunctions/Services/IEmbedService.cs

Lines changed: 0 additions & 19 deletions
This file was deleted.
Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
34
namespace EmbedFunctions.Services;
45

56
internal sealed class MilvusEmbedService : IEmbedService
67
{
7-
public Task<bool> EmbedBlobAsync(Stream blobStream, string blobName) => throw new NotImplementedException();
8+
public Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
9+
{
10+
throw new NotImplementedException();
11+
}
12+
13+
public Task<bool> EmbedImageBlobAsync(Stream imageStream, string imageName, CancellationToken ct = default)
14+
{
15+
throw new NotImplementedException();
16+
}
17+
18+
public Task<bool> EmbedPDFBlobAsync(Stream blobStream, string blobName) => throw new NotImplementedException();
19+
20+
public Task EnsureSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
21+
{
22+
throw new NotImplementedException();
23+
}
824
}
Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,25 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
34
namespace EmbedFunctions.Services;
45

56
internal sealed class PineconeEmbedService : IEmbedService
67
{
7-
public Task<bool> EmbedBlobAsync(Stream blobStream, string blobName) => throw new NotImplementedException(
8+
public Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
9+
{
10+
throw new NotImplementedException();
11+
}
12+
13+
public Task<bool> EmbedImageBlobAsync(Stream imageStream, string imageName, CancellationToken ct = default)
14+
{
15+
throw new NotImplementedException();
16+
}
17+
18+
public Task<bool> EmbedPDFBlobAsync(Stream blobStream, string blobName) => throw new NotImplementedException(
819
"Pinecone embedding isn't implemented.");
20+
21+
public Task EnsureSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
22+
{
23+
throw new NotImplementedException();
24+
}
925
}
Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
34
namespace EmbedFunctions.Services;
45

56
internal sealed class QdrantEmbedService : IEmbedService
67
{
7-
public Task<bool> EmbedBlobAsync(Stream blobStream, string blobName) => throw new NotImplementedException();
8+
public Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
9+
{
10+
throw new NotImplementedException();
11+
}
12+
13+
public Task<bool> EmbedImageBlobAsync(Stream imageStream, string imageName, CancellationToken ct = default)
14+
{
15+
throw new NotImplementedException();
16+
}
17+
18+
public Task<bool> EmbedPDFBlobAsync(Stream blobStream, string blobName) => throw new NotImplementedException();
19+
20+
public Task EnsureSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
21+
{
22+
throw new NotImplementedException();
23+
}
824
}

app/prepdocs/PrepareDocs/AppOptions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ internal record class AppOptions(
1616
bool Remove,
1717
bool RemoveAll,
1818
string? FormRecognizerServiceEndpoint,
19+
string? ComputerVisionServiceEndpoint,
1920
bool Verbose,
2021
IConsole Console) : AppConsole(Console);
2122

app/prepdocs/PrepareDocs/PrepareDocs.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
</ItemGroup>
2121

2222
<ItemGroup>
23-
<ProjectReference Include="..\..\functions\EmbedFunctions\EmbedFunctions.csproj" />
23+
<ProjectReference Include="..\..\shared\Shared\Shared.csproj" />
2424
</ItemGroup>
2525

2626
</Project>

app/prepdocs/PrepareDocs/Program.Clients.cs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3-
4-
using EmbedFunctions.Services;
5-
63
internal static partial class Program
74
{
85
private static BlobContainerClient? s_corpusContainerClient;
@@ -30,8 +27,19 @@ private static Task<AzureSearchEmbedService> GetAzureSearchEmbedService(AppOptio
3027
var openAIClient = await GetAzureOpenAIClientAsync(o);
3128
var embeddingModelName = o.EmbeddingModelName ?? throw new ArgumentNullException(nameof(o.EmbeddingModelName));
3229
var searchIndexName = o.SearchIndexName ?? throw new ArgumentNullException(nameof(o.SearchIndexName));
33-
34-
return new AzureSearchEmbedService(openAIClient, embeddingModelName, searchClient, searchIndexName, searchIndexClient, documentClient, blobContainerClient, null);
30+
var computerVisionService = await GetComputerVisionServiceAsync(o);
31+
32+
return new AzureSearchEmbedService(
33+
openAIClient: openAIClient,
34+
embeddingModelName: embeddingModelName,
35+
searchClient: searchClient,
36+
searchIndexName: searchIndexName,
37+
searchIndexClient: searchIndexClient,
38+
documentAnalysisClient: documentClient,
39+
corpusContainerClient: blobContainerClient,
40+
computerVisionService: computerVisionService,
41+
includeImageEmbeddingsField: computerVisionService != null,
42+
logger: null);
3543
});
3644

3745
private static Task<BlobContainerClient> GetCorpusBlobContainerClientAsync(AppOptions options) =>
@@ -139,6 +147,20 @@ private static Task<SearchClient> GetSearchClientAsync(AppOptions options) =>
139147
return s_searchClient;
140148
});
141149

150+
private static Task<IComputerVisionService?> GetComputerVisionServiceAsync(AppOptions options) =>
151+
GetLazyClientAsync<IComputerVisionService?>(options, s_openAILock, async o =>
152+
{
153+
await Task.CompletedTask;
154+
var endpoint = o.ComputerVisionServiceEndpoint;
155+
156+
if (string.IsNullOrEmpty(endpoint))
157+
{
158+
return null;
159+
}
160+
161+
return new AzureComputerVisionService(new HttpClient(), endpoint, DefaultCredential);
162+
});
163+
142164
private static Task<OpenAIClient> GetAzureOpenAIClientAsync(AppOptions options) =>
143165
GetLazyClientAsync<OpenAIClient>(options, s_openAILock, async o =>
144166
{

app/prepdocs/PrepareDocs/Program.Options.cs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ internal static partial class Program
4141
private static readonly Option<string> s_formRecognizerServiceEndpoint =
4242
new(name: "--formrecognizerendpoint", description: "Optional. The Azure Form Recognizer service endpoint which will be used to extract text, tables and layout from the documents (must exist already)");
4343

44+
private static readonly Option<string> s_computerVisionServiceEndpoint =
45+
new(name: "--computervisionendpoint", description: "Optional. The Azure Computer Vision service endpoint which will be used to vectorize image and query");
46+
4447
private static readonly Option<bool> s_verbose =
4548
new(aliases: new[] { "--verbose", "-v" }, description: "Verbose output");
4649

@@ -49,11 +52,23 @@ internal static partial class Program
4952
Prepare documents by extracting content from PDFs, splitting content into sections,
5053
uploading to blob storage, and indexing in a search index.
5154
""")
52-
{
53-
s_files, s_category, s_skipBlobs, s_storageEndpoint,
54-
s_container, s_tenantId, s_searchService, s_searchIndexName, s_azureOpenAIService, s_embeddingModelName,
55-
s_remove, s_removeAll, s_formRecognizerServiceEndpoint, s_verbose
56-
};
55+
{
56+
s_files,
57+
s_category,
58+
s_skipBlobs,
59+
s_storageEndpoint,
60+
s_container,
61+
s_tenantId,
62+
s_searchService,
63+
s_searchIndexName,
64+
s_azureOpenAIService,
65+
s_embeddingModelName,
66+
s_remove,
67+
s_removeAll,
68+
s_formRecognizerServiceEndpoint,
69+
s_computerVisionServiceEndpoint,
70+
s_verbose,
71+
};
5772

5873
private static AppOptions GetParsedAppOptions(InvocationContext context) => new(
5974
Files: context.ParseResult.GetValueForArgument(s_files),
@@ -69,6 +84,7 @@ internal static partial class Program
6984
Remove: context.ParseResult.GetValueForOption(s_remove),
7085
RemoveAll: context.ParseResult.GetValueForOption(s_removeAll),
7186
FormRecognizerServiceEndpoint: context.ParseResult.GetValueForOption(s_formRecognizerServiceEndpoint),
87+
ComputerVisionServiceEndpoint: context.ParseResult.GetValueForOption(s_computerVisionServiceEndpoint),
7288
Verbose: context.ParseResult.GetValueForOption(s_verbose),
7389
Console: context.Console);
7490
}

app/prepdocs/PrepareDocs/Program.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3-
using EmbedFunctions.Services;
3+
using System.Diagnostics;
44

55
s_rootCommand.SetHandler(
66
async (context) =>
@@ -15,7 +15,6 @@
1515
{
1616
var searchIndexName = options.SearchIndexName ?? throw new ArgumentNullException(nameof(options.SearchIndexName));
1717
var embedService = await GetAzureSearchEmbedService(options);
18-
1918
await embedService.EnsureSearchIndexAsync(options.SearchIndexName);
2019

2120
Matcher matcher = new();
@@ -190,19 +189,26 @@ static async ValueTask UploadBlobsAndCreateIndexAsync(
190189
// revert stream position
191190
stream.Position = 0;
192191

193-
await embeddingService.EmbedBlobAsync(stream, documentName);
192+
await embeddingService.EmbedPDFBlobAsync(stream, documentName);
194193
}
195194
finally
196195
{
197196
File.Delete(tempFileName);
198197
}
199198
}
200199
}
200+
// If it's an image (ends with .png/.jpg/.jpeg), upload it to blob storage and embed it.
201+
else if (Path.GetExtension(fileName).Equals(".png", StringComparison.OrdinalIgnoreCase) ||
202+
Path.GetExtension(fileName).Equals(".jpg", StringComparison.OrdinalIgnoreCase) ||
203+
Path.GetExtension(fileName).Equals(".jpeg", StringComparison.OrdinalIgnoreCase))
204+
{
205+
await embeddingService.EmbedImageBlobAsync(File.OpenRead(fileName), fileName);
206+
}
201207
else
202208
{
203209
var blobName = BlobNameFromFilePage(fileName);
204210
await UploadBlobAsync(fileName, blobName, container);
205-
await embeddingService.EmbedBlobAsync(File.OpenRead(fileName), blobName);
211+
await embeddingService.EmbedPDFBlobAsync(File.OpenRead(fileName), blobName);
206212
}
207213
}
208214

app/shared/Shared/Models/ApproachResponse.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33
namespace Shared.Models;
44

55
public record SupportingContentRecord(string Title, string Content);
6+
7+
public record SupportingImageRecord(string Title, string Url);
8+
69
public record ApproachResponse(
710
string Answer,
811
string? Thoughts,
9-
SupportingContentRecord[] DataPoints, // title, content
12+
SupportingContentRecord[]? DataPoints, // title, content
13+
SupportingImageRecord[]? Images, // title, url
1014
string CitationBaseUrl,
1115
string? Error = null);

app/shared/Shared/Services/AzureComputerVisionService.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
using System.Text.Json;
66
using Azure.Core;
77

8-
public class AzureComputerVisionService(HttpClient client, string endPoint, TokenCredential tokenCredential)
8+
public class AzureComputerVisionService(HttpClient client, string endPoint, TokenCredential tokenCredential) : IComputerVisionService
99
{
10+
public int Dimension => 1024;
11+
1012
// add virtual keyword to make it mockable
1113
public async Task<ImageEmbeddingResponse> VectorizeImageAsync(string imagePathOrUrl, CancellationToken ct = default)
1214
{

0 commit comments

Comments
 (0)