Add GPT-4V support in ChatService #272

Merged 5 commits on Feb 8, 2024
4 changes: 2 additions & 2 deletions app/Directory.Packages.props
@@ -4,7 +4,7 @@
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="Azure.AI.FormRecognizer" Version="4.1.0" />
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.8" />
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.12" />
<PackageVersion Include="Azure.Extensions.AspNetCore.Configuration.Secrets" Version="1.3.0" />
<PackageVersion Include="Azure.Identity" Version="1.10.4" />
<PackageVersion Include="Azure.Search.Documents" Version="11.5.1" />
@@ -33,7 +33,7 @@
<PackageVersion Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="8.0.0" />
<PackageVersion Include="Microsoft.ML" Version="3.0.0" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.8.0" />
<PackageVersion Include="Microsoft.SemanticKernel" Version="0.24.230918.1-preview" />
<PackageVersion Include="Microsoft.SemanticKernel" Version="1.3.0" />
<PackageVersion Include="Microsoft.VisualStudio.Azure.Containers.Tools.Targets" Version="1.19.5" />
<PackageVersion Include="MudBlazor" Version="6.11.1" />
<PackageVersion Include="PdfSharpCore" Version="1.3.62" />
26 changes: 9 additions & 17 deletions app/backend/Extensions/WebApplicationExtensions.cs
@@ -44,36 +44,28 @@ private static async IAsyncEnumerable<ChatChunkResponse> OnPostChatPromptAsync(
{
var deploymentId = config["AZURE_OPENAI_CHATGPT_DEPLOYMENT"];
var response = await client.GetChatCompletionsStreamingAsync(
deploymentId, new ChatCompletionsOptions
new ChatCompletionsOptions
{
DeploymentName = deploymentId,
Messages =
{
new ChatMessage(ChatRole.System, """
new ChatRequestSystemMessage("""
You're an AI assistant for developers, helping them write code more efficiently.
You're name is **Blazor 📎 Clippy** and you're an expert Blazor developer.
You're also an expert in ASP.NET Core, C#, TypeScript, and even JavaScript.
You will always reply with a Markdown formatted response.
"""),

new ChatMessage(ChatRole.User, "What's your name?"),

new ChatMessage(ChatRole.Assistant,
"Hi, my name is **Blazor 📎 Clippy**! Nice to meet you."),

new ChatMessage(ChatRole.User, prompt.Prompt)
new ChatRequestUserMessage("What's your name?"),
new ChatRequestAssistantMessage("Hi, my name is **Blazor 📎 Clippy**! Nice to meet you."),
new ChatRequestUserMessage(prompt.Prompt)
}
}, cancellationToken);

using var completions = response.Value;
await foreach (var choice in completions.GetChoicesStreaming(cancellationToken))
await foreach (var choice in response.WithCancellation(cancellationToken))
{
await foreach (var message in choice.GetMessageStreaming(cancellationToken))
if (choice.ContentUpdate is { Length: > 0 })
{
if (message is { Content.Length: > 0 })
{
var (length, content) = (message.Content.Length, message.Content);
yield return new ChatChunkResponse(length, content);
}
yield return new ChatChunkResponse(choice.ContentUpdate.Length, choice.ContentUpdate);
}
}
}
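For reference, a condensed, self-contained sketch (not taken from this PR) of the Azure.AI.OpenAI 1.0.0-beta.12 streaming shape the hunk above migrates to: the deployment name moves into `ChatCompletionsOptions.DeploymentName`, typed `ChatRequest*Message` classes replace `ChatMessage`/`ChatRole`, and the call returns a `StreamingResponse<StreamingChatCompletionsUpdate>` that is enumerated directly. The endpoint, key, and deployment name below are placeholders.

```csharp
using Azure;
using Azure.AI.OpenAI;

// Placeholder endpoint, key, and deployment -- not values from this repository.
OpenAIClient client = new(new Uri("https://example.openai.azure.com"), new AzureKeyCredential("<api-key>"));

var options = new ChatCompletionsOptions
{
    DeploymentName = "chat",   // deployment is now set here rather than passed to the method
    Messages =
    {
        new ChatRequestSystemMessage("You are a helpful assistant."),
        new ChatRequestUserMessage("Hello!")
    }
};

// The async call takes only the options and yields incremental updates.
StreamingResponse<StreamingChatCompletionsUpdate> response =
    await client.GetChatCompletionsStreamingAsync(options);

await foreach (StreamingChatCompletionsUpdate update in response)
{
    if (update.ContentUpdate is { Length: > 0 })
    {
        Console.Write(update.ContentUpdate);   // partial content as it streams in
    }
}
```

In the handler above, each non-empty `ContentUpdate` is wrapped in a `ChatChunkResponse` instead of being written to the console.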
2 changes: 0 additions & 2 deletions app/backend/GlobalUsings.cs
@@ -13,8 +13,6 @@
global using Microsoft.AspNetCore.Mvc;
global using Microsoft.AspNetCore.Mvc.RazorPages;
global using Microsoft.SemanticKernel;
global using Microsoft.SemanticKernel.AI.ChatCompletion;
global using Microsoft.SemanticKernel.AI.Embeddings;
global using MinimalApi.Extensions;
global using MinimalApi.Services;
global using PdfSharpCore.Pdf;
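The two removed global usings come from the pre-1.0 Semantic Kernel API surface; assuming the 1.x package layout, their replacements are imported file-locally in `ReadRetrieveReadChatService.cs` below, roughly:

```csharp
// Assumed mapping to the 1.x namespaces used later in this PR:
using Microsoft.SemanticKernel.ChatCompletion;    // IChatCompletionService, ChatHistory
using Microsoft.SemanticKernel.Connectors.OpenAI; // OpenAIPromptExecutionSettings
using Microsoft.SemanticKernel.Embeddings;        // ITextEmbeddingGenerationService
```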
131 changes: 97 additions & 34 deletions app/backend/Services/ReadRetrieveReadChatService.cs
@@ -1,32 +1,59 @@
// Copyright (c) Microsoft. All rights reserved.

namespace MinimalApi.Services;
using Azure.Core;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Embeddings;

namespace MinimalApi.Services;
#pragma warning disable SKEXP0011 // Mark members as static
#pragma warning disable SKEXP0001 // Mark members as static
public class ReadRetrieveReadChatService
{
private readonly ISearchService _searchClient;
private readonly IKernel _kernel;
private readonly Kernel _kernel;
private readonly IConfiguration _configuration;
private readonly IComputerVisionService? _visionService;
private readonly TokenCredential? _tokenCredential;

public ReadRetrieveReadChatService(
ISearchService searchClient,
OpenAIClient client,
IConfiguration configuration)
IConfiguration configuration,
IComputerVisionService? visionService = null,
TokenCredential? tokenCredential = null)
{
_searchClient = searchClient;
var deployedModelName = configuration["AzureOpenAiChatGptDeployment"];
ArgumentNullException.ThrowIfNullOrWhiteSpace(deployedModelName);
var kernelBuilder = Kernel.CreateBuilder();

var kernelBuilder = Kernel.Builder.WithAzureChatCompletionService(deployedModelName, client);
var embeddingModelName = configuration["AzureOpenAiEmbeddingDeployment"];
if (!string.IsNullOrEmpty(embeddingModelName))
if (configuration["UseAOAI"] != "true")
{
var endpoint = configuration["AzureOpenAiServiceEndpoint"];
ArgumentNullException.ThrowIfNullOrWhiteSpace(endpoint);
kernelBuilder = kernelBuilder.WithAzureTextEmbeddingGenerationService(embeddingModelName, endpoint, new DefaultAzureCredential());
var deployment = configuration["OpenAiChatGptDeployment"];
ArgumentNullException.ThrowIfNullOrWhiteSpace(deployment);
kernelBuilder = kernelBuilder.AddOpenAIChatCompletion(deployment, client);

var embeddingModelName = configuration["OpenAiEmbeddingDeployment"];
ArgumentNullException.ThrowIfNullOrWhiteSpace(embeddingModelName);
kernelBuilder = kernelBuilder.AddOpenAITextEmbeddingGeneration(embeddingModelName, client);
}
else
{
var deployedModelName = configuration["AzureOpenAiChatGptDeployment"];
ArgumentNullException.ThrowIfNullOrWhiteSpace(deployedModelName);
var embeddingModelName = configuration["AzureOpenAiEmbeddingDeployment"];
if (!string.IsNullOrEmpty(embeddingModelName))
{
var endpoint = configuration["AzureOpenAiServiceEndpoint"];
ArgumentNullException.ThrowIfNullOrWhiteSpace(endpoint);
kernelBuilder = kernelBuilder.AddAzureOpenAITextEmbeddingGeneration(embeddingModelName, endpoint, tokenCredential ?? new DefaultAzureCredential());
kernelBuilder = kernelBuilder.AddAzureOpenAIChatCompletion(deployedModelName, endpoint, tokenCredential ?? new DefaultAzureCredential());
}
}

_kernel = kernelBuilder.Build();
_configuration = configuration;
_visionService = visionService;
_tokenCredential = tokenCredential;
}

public async Task<ApproachResponse> ReplyAsync(
@@ -39,8 +66,8 @@ public async Task<ApproachResponse> ReplyAsync(
var useSemanticRanker = overrides?.SemanticRanker ?? false;
var excludeCategory = overrides?.ExcludeCategory ?? null;
var filter = excludeCategory is null ? null : $"category ne '{excludeCategory}'";
IChatCompletion chat = _kernel.GetService<IChatCompletion>();
ITextEmbeddingGeneration? embedding = _kernel.GetService<ITextEmbeddingGeneration>();
var chat = _kernel.GetRequiredService<IChatCompletionService>();
var embedding = _kernel.GetRequiredService<ITextEmbeddingGenerationService>();
float[]? embeddings = null;
var question = history.LastOrDefault()?.User is { } userQuestion
? userQuestion
@@ -55,24 +82,19 @@
string? query = null;
if (overrides?.RetrievalMode != RetrievalMode.Vector)
{
var getQueryChat = chat.CreateNewChat(@"You are a helpful AI assistant, generate search query for followup question.
var getQueryChat = new ChatHistory(@"You are a helpful AI assistant, generate search query for followup question.
Make your respond simple and precise. Return the query only, do not return any other text.
e.g.
Northwind Health Plus AND standard plan.
standard plan AND dental AND employee benefit.
");

getQueryChat.AddUserMessage(question);
var result = await chat.GetChatCompletionsAsync(
var result = await chat.GetChatMessageContentAsync(
getQueryChat,
cancellationToken: cancellationToken);

if (result.Count != 1)
{
throw new InvalidOperationException("Failed to get search query");
}

query = result[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content;
query = result.Content ?? throw new InvalidOperationException("Failed to get search query");
}

// step 2
@@ -89,12 +111,19 @@ standard plan AND dental AND employee benefit.
documentContents = string.Join("\r", documentContentList.Select(x =>$"{x.Title}:{x.Content}"));
}

Console.WriteLine(documentContents);
// step 2.5
// retrieve images if _visionService is available
SupportingImageRecord[]? images = default;
if (_visionService is not null)
{
var queryEmbeddings = await _visionService.VectorizeTextAsync(query ?? question, cancellationToken);
images = await _searchClient.QueryImagesAsync(query, queryEmbeddings.vector, overrides, cancellationToken);
}

// step 3
// put together related docs and conversation history to generate answer
var answerChat = chat.CreateNewChat(
"You are a system assistant who helps the company employees with their healthcare " +
"plan questions, and questions about the employee handbook. Be brief in your answers");
var answerChat = new ChatHistory(
"You are a system assistant who helps the company employees with their questions. Be brief in your answers");

// add chat history
foreach (var turn in history)
@@ -106,22 +135,56 @@ standard plan AND dental AND employee benefit.
}
}

// format prompt
answerChat.AddUserMessage(@$" ## Source ##

if (images != null)
{
var prompt = @$"## Source ##
{documentContents}
## End ##

Answer question based on available source and images.
Your answer needs to be a json object with answer and thoughts field.
Don't put your answer between ```json and ```, return the json string directly. e.g {{""answer"": ""I don't know"", ""thoughts"": ""I don't know""}}";

var tokenRequestContext = new TokenRequestContext(new[] { "https://storage.azure.com/.default" });
var sasToken = await (_tokenCredential?.GetTokenAsync(tokenRequestContext, cancellationToken) ?? throw new InvalidOperationException("Failed to get token"));
var sasTokenString = sasToken.Token;
var imageUrls = images.Select(x => $"{x.Url}?{sasTokenString}").ToArray();
var collection = new ChatMessageContentItemCollection();
collection.Add(new TextContent(prompt));
foreach (var imageUrl in imageUrls)
{
collection.Add(new ImageContent(new Uri(imageUrl)));
}

answerChat.AddUserMessage(collection);
}
else
{
var prompt = @$" ## Source ##
{documentContents}
## End ##

You answer needs to be a json object with the following format.
{{
""answer"": // the answer to the question, add a source reference to the end of each sentence. e.g. Apple is a fruit [reference1.pdf][reference2.pdf]. If no source available, put the answer as I don't know.
""thoughts"": // brief thoughts on how you came up with the answer, e.g. what sources you used, what you thought about, etc.
}}");
}}";
answerChat.AddUserMessage(prompt);
}

var promptExecutingSetting = new OpenAIPromptExecutionSettings
{
MaxTokens = 1024,
Temperature = overrides?.Temperature ?? 0.7,
};

// get answer
var answer = await chat.GetChatCompletionsAsync(
var answer = await chat.GetChatMessageContentAsync(
answerChat,
promptExecutingSetting,
cancellationToken: cancellationToken);
var answerJson = answer[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content;
var answerJson = answer.Content ?? throw new InvalidOperationException("Failed to get search query");
var answerObject = JsonSerializer.Deserialize<JsonElement>(answerJson);
var ans = answerObject.GetProperty("answer").GetString() ?? throw new InvalidOperationException("Failed to get answer");
var thoughts = answerObject.GetProperty("thoughts").GetString() ?? throw new InvalidOperationException("Failed to get thoughts");
@@ -130,7 +193,7 @@ You answer needs to be a json object with the following format.
// add follow up questions if requested
if (overrides?.SuggestFollowupQuestions is true)
{
var followUpQuestionChat = chat.CreateNewChat(@"You are a helpful AI assistant");
var followUpQuestionChat = new ChatHistory(@"You are a helpful AI assistant");
followUpQuestionChat.AddUserMessage($@"Generate three follow-up question based on the answer you just generated.
# Answer
{ans}
@@ -144,11 +207,11 @@ Return the follow-up question as a json string list.
""What is the out-of-pocket maximum?""
]");

var followUpQuestions = await chat.GetChatCompletionsAsync(
var followUpQuestions = await chat.GetChatMessageContentAsync(
followUpQuestionChat,
cancellationToken: cancellationToken);

var followUpQuestionsJson = followUpQuestions[0].ModelResult.GetOpenAIChatResult().Choice.Message.Content;
var followUpQuestionsJson = followUpQuestions.Content ?? throw new InvalidOperationException("Failed to get search query");
var followUpQuestionsObject = JsonSerializer.Deserialize<JsonElement>(followUpQuestionsJson);
var followUpQuestionsList = followUpQuestionsObject.EnumerateArray().Select(x => x.GetString()).ToList();
foreach (var followUpQuestion in followUpQuestionsList)
@@ -158,7 +221,7 @@
}
return new ApproachResponse(
DataPoints: documentContentList,
Images: null,
Images: images,
Answer: ans,
Thoughts: thoughts,
CitationBaseUrl: _configuration.ToCitationBaseUrl());
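As a reference for the GPT-4V branch added above, here is a minimal sketch of the Semantic Kernel 1.x multimodal message pattern. It assumes a `Kernel` already configured with a vision-capable Azure OpenAI chat deployment, as in the service constructor; the question text, image URL, and SAS token are placeholders, not values from this PR.

```csharp
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;

// Assumes `kernel` was built with a GPT-4V-capable chat deployment.
var chat = kernel.GetRequiredService<IChatCompletionService>();

var history = new ChatHistory("You answer questions about the supplied sources and images.");

// A single user turn can mix text and image items.
var items = new ChatMessageContentItemCollection
{
    new TextContent("Which plan does this benefits chart describe?"),
    new ImageContent(new Uri("https://example.blob.core.windows.net/images/chart.png?<sas-token>"))
};
history.AddUserMessage(items);

var settings = new OpenAIPromptExecutionSettings { MaxTokens = 1024, Temperature = 0.3 };
var reply = await chat.GetChatMessageContentAsync(history, settings);
Console.WriteLine(reply.Content);
```

The service above builds the same kind of item collection from the retrieved `SupportingImageRecord` URLs, appending a storage SAS token obtained through the injected `TokenCredential`.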
5 changes: 4 additions & 1 deletion app/prepdocs/PrepareDocs/Program.cs
@@ -202,7 +202,10 @@ static async ValueTask UploadBlobsAndCreateIndexAsync(
Path.GetExtension(fileName).Equals(".jpg", StringComparison.OrdinalIgnoreCase) ||
Path.GetExtension(fileName).Equals(".jpeg", StringComparison.OrdinalIgnoreCase))
{
await embeddingService.EmbedImageBlobAsync(File.OpenRead(fileName), fileName);
using var stream = File.OpenRead(fileName);
var blobName = BlobNameFromFilePage(fileName);
await UploadBlobAsync(fileName, blobName, container);
await embeddingService.EmbedImageBlobAsync(stream, fileName);
}
else
{
39 changes: 9 additions & 30 deletions app/shared/Shared/Services/AzureSearchEmbedService.cs
@@ -73,53 +73,32 @@ Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
}
}

public async Task<bool> EmbedImageBlobAsync(Stream imageStream, string imageName, CancellationToken ct = default)
public async Task<bool> EmbedImageBlobAsync(
Stream imageStream,
string imageUrl,
CancellationToken ct = default)
{
if (includeImageEmbeddingsField == false || computerVisionService is null)
{
throw new InvalidOperationException(
"Computer Vision service is required to include image embeddings field");
}

// step 1
// upload image to blob storage
var blobClient = corpusContainerClient.GetBlobClient(imageName);
if (await blobClient.ExistsAsync())
{
logger?.LogWarning("Blob '{BlobName}' already exists", imageName);
}
else
{
logger?.LogInformation("Uploading image '{ImageName}'", imageName);
await blobClient.UploadAsync(imageStream, new BlobHttpHeaders
{
ContentType = "image"
});
}

// step 2
// get image embeddings
imageStream.Position = 0;
var tempPath = Path.GetTempFileName();
await using var tempStream = File.OpenWrite(tempPath);
await imageStream.CopyToAsync(tempStream, ct);
tempStream.Close();

var embeddings = await computerVisionService.VectorizeImageAsync(tempPath, ct);
var embeddings = await computerVisionService.VectorizeImageAsync(imageUrl, ct);

// id can only contain letters, digits, underscore (_), dash (-), or equal sign (=).
var imageId = MatchInSetRegex().Replace(imageName, "_").TrimStart('_');
var imageId = MatchInSetRegex().Replace(imageUrl, "_").TrimStart('_');
// step 3
// index image embeddings
var indexAction = new IndexDocumentsAction<SearchDocument>(
IndexActionType.MergeOrUpload,
new SearchDocument
{
["id"] = imageId,
["content"] = imageName,
["content"] = imageUrl,
["category"] = "image",
["imageEmbedding"] = embeddings.vector,
["sourcefile"] = blobClient.Uri.ToString(),
["sourcefile"] = imageUrl,
});

var batch = new IndexDocumentsBatch<SearchDocument>();
@@ -469,7 +448,7 @@ private async Task IndexSectionsAsync(IEnumerable<Section> sections)
var batch = new IndexDocumentsBatch<SearchDocument>();
foreach (var section in sections)
{
var embeddings = await openAIClient.GetEmbeddingsAsync(embeddingModelName, new Azure.AI.OpenAI.EmbeddingsOptions(section.Content.Replace('\r', ' ')));
var embeddings = await openAIClient.GetEmbeddingsAsync(new Azure.AI.OpenAI.EmbeddingsOptions(embeddingModelName, [section.Content.Replace('\r', ' ')]));
var embedding = embeddings.Value.Data.FirstOrDefault()?.Embedding.ToArray() ?? [];
batch.Actions.Add(new IndexDocumentsAction<SearchDocument>(
IndexActionType.MergeOrUpload,
@@ -46,7 +46,7 @@ public async Task QueryDocumentsTestEmbeddingOnlyAsync()
var openAiEmbeddingDeployment = Environment.GetEnvironmentVariable("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") ?? throw new InvalidOperationException();
var openAIClient = new OpenAIClient(new Uri(openAiEndpoint), new DefaultAzureCredential());
var query = "What is included in my Northwind Health Plus plan that is not in standard?";
var embeddingResponse = await openAIClient.GetEmbeddingsAsync(openAiEmbeddingDeployment, new EmbeddingsOptions(query));
var embeddingResponse = await openAIClient.GetEmbeddingsAsync(new EmbeddingsOptions(openAiEmbeddingDeployment, [query]));
var embedding = embeddingResponse.Value.Data.First().Embedding;
var searchClient = new SearchClient(new Uri(searchServceEndpoint), index, new DefaultAzureCredential());
var service = new AzureSearchService(searchClient);
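Both the indexing code and the test above switch to the beta.12 `EmbeddingsOptions` shape, where the deployment name and input strings go into the options constructor instead of being passed to `GetEmbeddingsAsync`. A standalone sketch of that call, with placeholder endpoint and deployment names:

```csharp
using Azure.AI.OpenAI;
using Azure.Identity;

// Placeholder endpoint and deployment -- not values from this repository.
var client = new OpenAIClient(new Uri("https://example.openai.azure.com"), new DefaultAzureCredential());

var options = new EmbeddingsOptions("embedding", ["What is included in my plan?"]);
var response = await client.GetEmbeddingsAsync(options);

// Each input string yields one embedding vector.
float[] vector = response.Value.Data[0].Embedding.ToArray();
Console.WriteLine($"Embedding length: {vector.Length}");
```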