Add Mistral AI Chat Completion support to Inference Plugin #128538
Conversation
Looking good. I left a few suggestions.
Could you update the description of the PR so that the example requests are formatted?
Let's also wrap them in code blocks using the three backticks. Thanks for making them collapsable sections though!
```java
builder.startObject(STREAM_OPTIONS_FIELD);
builder.field(INCLUDE_USAGE_FIELD, true);
builder.endObject();
fillStreamOptionsFields(builder);
```
Just an FYI we have some inflight changes that'll affect how we do this: #128592
Thank you for the heads up! I inspected the changes in the attached PR. They don't affect my changes. We're good.
Sorry, what I meant is that we can use the same approach here. Let's leverage the `Params` class to determine whether to serialize the stream options, the same way the PR I listed is doing. That way we don't need to subclass.
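A minimal sketch of the `Params`-driven approach the reviewer describes. All names here are hypothetical stand-ins: the real code uses Elasticsearch's `ToXContent.Params` and `XContentBuilder`, which are mimicked below with a plain `Map` and a `StringBuilder`.

```java
import java.util.Map;

// Sketch: serialization decided by a request-scoped parameter instead of a subclass.
public class StreamOptionsSketch {
    static final String INCLUDE_STREAM_OPTIONS_PARAM = "include_stream_options";

    // Stand-in for XContentBuilder-based serialization.
    static String serialize(boolean stream, Map<String, String> params) {
        StringBuilder builder = new StringBuilder("{\"stream\":" + stream);
        boolean include = Boolean.parseBoolean(
            params.getOrDefault(INCLUDE_STREAM_OPTIONS_PARAM, "true"));
        // No subclass needed: callers pass a param rather than overriding a method.
        if (stream && include) {
            builder.append(",\"stream_options\":{\"include_usage\":true}");
        }
        builder.append("}");
        return builder.toString();
    }

    public static void main(String[] args) {
        // Default: stream options included.
        System.out.println(serialize(true, Map.of()));
        // A provider opting out via the param.
        System.out.println(serialize(true, Map.of(INCLUDE_STREAM_OPTIONS_PARAM, "false")));
    }
}
```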
Done.
...ference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java
.../org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponseEntity.java
...e/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java
```java
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponseEntity;

/**
 * Handles non-streaming chat completion responses for Mistral models, extending the OpenAI chat completion response handler.
```
```diff
- * Handles non-streaming chat completion responses for Mistral models, extending the OpenAI chat completion response handler.
+ * Handles non-streaming completion responses for Mistral models, extending the OpenAI chat completion response handler.
```
Fixed. Thanks.
...va/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java
.../org/elasticsearch/xpack/inference/services/openai/MistralChatCompletionResponseHandler.java
@prwhelan any suggestions on how to intentionally encounter a midstream error while testing?
To followup, I don't think we have a great way to test this.
@Jan-Kazlouski-elastic can you try initiating a streaming request and disabling your internet (or some similar failure) in the middle of the response being streamed?
I've tried this several times and consistently received `read ECONNRESET` errors; these are client-side errors coming from Postman, not server-side ones. By shutting off the client's internet, we prevent it from receiving any further stream data, including both the remaining valid chunks and any potential error payload.
If Mistral does construct an error response, it would still need to reach the client, but in this case the client is no longer able to receive anything.
All errors returned by Mistral so far are strictly non-streaming. The only scenario where Mistral might return an error midstream would be a genuine server-side malfunction on their part, or a rate limit; but testing that would be too expensive, and even then the error would probably be non-streaming.
Maybe we could contact Mistral to clarify that?
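For future reference, one way to provoke a midstream failure deterministically in a test is a local server that declares a long SSE body, emits one chunk, then drops the connection. This is a hypothetical harness sketched with only JDK classes, not code from this PR:

```java
import com.sun.net.httpserver.HttpServer;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class MidstreamFailureSketch {

    static String run() throws Exception {
        HttpServer server = HttpServer.create(new InetSocketAddress(0), 0);
        server.createContext("/v1/chat/completions", exchange -> {
            exchange.getResponseHeaders().add("Content-Type", "text/event-stream");
            // Declare far more bytes than we will send, so the early close below
            // looks like a truncated stream to the client.
            exchange.sendResponseHeaders(200, 1_000_000);
            OutputStream out = exchange.getResponseBody();
            out.write("data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n\n".getBytes());
            out.flush();
            try {
                out.close(); // tears down the connection mid-stream
            } catch (IOException ignored) {
                // the server complains about the short body; that is the point
            }
        });
        server.start();
        try {
            HttpClient client = HttpClient.newHttpClient();
            HttpRequest request = HttpRequest.newBuilder(
                URI.create("http://localhost:" + server.getAddress().getPort()
                    + "/v1/chat/completions")).build();
            client.send(request, HttpResponse.BodyHandlers.ofString());
            return "completed";
        } catch (IOException e) {
            return "midstream failure";
        } finally {
            server.stop(0);
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(run());
    }
}
```

This avoids depending on the provider ever emitting a streamed error, which, as noted above, is hard to trigger against the real Mistral API.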
> Maybe we could contact Mistral to clarify that?
Yeah let's bring this up at the next meeting and see if we can get Serena to follow up.
```java
 * Handles streaming chat completion responses and error parsing for Mistral inference endpoints.
 * Adapts the OpenAI handler to support Mistral's simpler error schema with fields like "message" and "http_status_code".
 */
public class MistralUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
```
If the midstream errors are in the same format as openai, how about we refactor the `OpenAiUnifiedChatCompletionResponseHandler` so that we can replace some of the strings that reference openai specifically? I think it's just the name of the parser:

Lines 115 to 120 in 2830768

```java
    "open_ai_error",
    true,
    args -> Optional.ofNullable((OpenAiErrorResponse) args[0])
);
private static final ConstructingObjectParser<OpenAiErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
    "open_ai_error",
```
Maybe we could extract those classes and rename them to be more generic?
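The extraction could look roughly like the following. This is a self-contained sketch with hypothetical stand-ins (the real code builds `ConstructingObjectParser` instances over XContent, not a `Function` over a `Map`); the point is that the provider-specific parser name becomes an argument instead of a hard-coded `"open_ai_error"` literal:

```java
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

public class GenericErrorParserSketch {
    record ErrorResponse(String message) {}

    // Shared factory: each provider supplies only its parser name.
    static Function<Map<String, Object>, Optional<ErrorResponse>> errorParser(String parserName) {
        return body -> Optional.ofNullable((String) body.get("message"))
            .map(msg -> new ErrorResponse(parserName + ": " + msg));
    }

    public static void main(String[] args) {
        var openAi = errorParser("open_ai_error");
        var mistral = errorParser("mistral_error");
        System.out.println(openAi.apply(Map.of("message", "bad request")).get().message());
        System.out.println(mistral.apply(Map.of("message", "bad request")).get().message());
    }
}
```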
Done.
But we're not positive that midstream errors are in the same format as OpenAI's; it's just assumed from the fact that Mistral exposes an OpenAI-style API.
…mistral-chat-completion-integration # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
Hi @jonathan-buttner. Also, I updated the section related to testing. Three backticks didn't work for me when I was creating the initial PR comment, but when I applied them afterwards they worked perfectly. Not sure why, but I will keep it in mind for future PRs.
New error format:

Create Completion Endpoint
- Not Found:
- Unauthorized:
- Invalid Model:

Perform Non-Streaming Completion
- Not Found:

Perform Streaming Completion
- Not Found:

Create Chat Completion Endpoint
- Not Found:
- Unauthorized:
- Invalid Model:

Perform Streaming Chat Completion
- Not Found:
- Negative Max Tokens:
- Invalid Model:
Thanks for the changes, left a few suggestions.
```java
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
MistralChatCompletionServiceSettings that = (MistralChatCompletionServiceSettings) o;
return Objects.equals(modelId, that.modelId);
```
I think we need to include `rateLimitSettings` here.
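A minimal sketch of what including `rateLimitSettings` in both `equals` and `hashCode` could look like. `RateLimitSettings` is stubbed as a record here, and the real class has more fields; this only illustrates the fix:

```java
import java.util.Objects;

public class ServiceSettingsEqualitySketch {
    record RateLimitSettings(long requestsPerMinute) {}

    static final class MistralChatCompletionServiceSettings {
        private final String modelId;
        private final RateLimitSettings rateLimitSettings;

        MistralChatCompletionServiceSettings(String modelId, RateLimitSettings rateLimitSettings) {
            this.modelId = modelId;
            this.rateLimitSettings = rateLimitSettings;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            MistralChatCompletionServiceSettings that = (MistralChatCompletionServiceSettings) o;
            // Both fields participate, so settings differing only in rate limits are not equal.
            return Objects.equals(modelId, that.modelId)
                && Objects.equals(rateLimitSettings, that.rateLimitSettings);
        }

        @Override
        public int hashCode() {
            return Objects.hash(modelId, rateLimitSettings);
        }
    }

    public static void main(String[] args) {
        var a = new MistralChatCompletionServiceSettings("mistral-small", new RateLimitSettings(240));
        var b = new MistralChatCompletionServiceSettings("mistral-small", new RateLimitSettings(60));
        System.out.println(a.equals(b)); // false: rate limits differ
    }
}
```

Keeping `equals` and `hashCode` over the same fields preserves their contract: equal objects must produce equal hashes.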
Missed that. Thank you. Fixed.
```java
@Override
public int hashCode() {
    return Objects.hash(modelId);
}
```
I think we need to include `rateLimitSettings` here.
Missed that. Thank you. Fixed.
...e/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java
```java
@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
```
I believe this block is nearly identical for Hugging Face, Mistral, and OpenAI. Could you take a shot at refactoring it to remove the duplication? Maybe we can lift the `instanceof` check out and somehow pass in the `MISTRAL_ERROR` type field, as it seems like those are the unique parts.
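The lifting described above could be sketched as follows. All names are hypothetical; the real `buildError` builds an `Exception` from an `HttpResult`, while this sketch reduces it to string-building to show just the parameterization of the type label and the `instanceof` check:

```java
public class BuildErrorSketch {
    static class ErrorResponse {
        final String message;
        ErrorResponse(String message) { this.message = message; }
    }
    static class MistralErrorResponse extends ErrorResponse {
        MistralErrorResponse(String message) { super(message); }
    }

    // Shared base logic: the instanceof check is lifted here once,
    // parameterized by the provider's expected subtype and type label.
    static String buildError(String errorType, ErrorResponse response,
                             Class<? extends ErrorResponse> expected) {
        if (expected.isInstance(response)) {
            return "[" + errorType + "] " + response.message;
        }
        return "[" + errorType + "] unknown error";
    }

    public static void main(String[] args) {
        System.out.println(buildError("mistral_error",
            new MistralErrorResponse("quota exceeded"), MistralErrorResponse.class));
        System.out.println(buildError("mistral_error",
            new ErrorResponse("raw"), MistralErrorResponse.class));
    }
}
```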
Could we do that in a separate PR?
Will do that in a separate PR and share the link here.
Sharing the PR:
#128923
```java
import java.util.Objects;
import java.util.Optional;

public class StreamingErrorResponse extends ErrorResponse {
```
Can you add a comment with an example error message that this would parse? Let's also add a TODO to note that `ErrorMessageResponseEntity` (https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java) is nearly identical (it doesn't parse as many fields) and we should remove the duplication.
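For illustration, an example of the kind of streamed error event such a class might see, with a deliberately naive extraction of its "message" field. The event shape and field names below are assumptions based on OpenAI-style APIs, not confirmed Mistral output, and the real class parses with XContent rather than a regex:

```java
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class StreamingErrorSketch {
    // e.g. an SSE event carrying an error object instead of a completion chunk:
    //   data: {"error":{"message":"The model does not exist.","type":"invalid_request_error"}}
    static final Pattern MESSAGE = Pattern.compile("\"message\"\\s*:\\s*\"([^\"]*)\"");

    static Optional<String> extractMessage(String sseEvent) {
        Matcher m = MESSAGE.matcher(sseEvent);
        return m.find() ? Optional.of(m.group(1)) : Optional.empty();
    }

    public static void main(String[] args) {
        String event = "data: {\"error\":{\"message\":\"The model does not exist.\","
            + "\"type\":\"invalid_request_error\"}}";
        System.out.println(extractMessage(event).orElse("no error"));
    }
}
```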
Done.
```java
public MistralChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, MistralChatCompletionModel model) {
    this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
    this.unifiedRequestEntity = new MistralUnifiedChatCompletionRequestEntity(unifiedChatInput);
```
Can we adjust this class to follow how we're doing it for openai now: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java
Done.
```java
import java.nio.charset.StandardCharsets;

/**
```
Can you add a few examples of what the error format looks like, or maybe just add them as tests. If you go the test route, just add a comment saying look at the tests for the expected error formats that we're aware of.
I made the change. I also removed the `unifiedChatInput` for Mistral, while for OpenAI it was left unused along with other fields:

```java
private static final String MODEL_FIELD = "model";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
private final UnifiedChatInput unifiedChatInput;
```

I know that refactoring is heavily discouraged by CONTRIBUTING.md, but perhaps I could clean it up for OpenAI as well? Seems like a pretty easy fix.
Thanks!

> I know that refactoring is heavily discouraged by CONTRIBUTING.md, but perhaps I could clean it up for OpenAI as well? Seems like a pretty easy fix.

Yeah, feel free to remove those unused variables; I think that improves the quality of the code. If it's small changes like you mentioned, I think it's fine. In future PRs, if you encounter this situation, I would just leave a GitHub review comment saying that these variables weren't used so you're removing them, or something like that, to make it clear to the reviewer why they're being removed.
Fix for OpenAI refactoring is delivered!
BTW I meant to discuss this in thread for MistralRequestEntity. Misclicked.
…stream options parameter
…tings in equality checks
…completion-integration # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
Looks good! Left a few suggestions. I think this PR is in a good spot, just ping me when you have the other PR for the refactoring finished so we can merge that one first.
```diff
- if (stream) {
-     fillStreamOptionsFields(builder);
+ // If request is streamed and skip stream options parameter is not true, include stream options in the request.
+ if (stream == true && params.paramAsBoolean(SKIP_STREAM_OPTIONS_PARAM, false) == false) {
```
nit: How about we reverse the naming? "skip" plus `false` seems closer to a double negative to me, so maybe:

```java
if (stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true) == true) {
```
Good thinking. Defaulting the boolean to true allows us not to fill it out for every other provider. I took another look at CONTRIBUTING.md; according to it, we should use an `==` check for boolean values only when checking for `false`. So I replaced it with:

```java
stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true)
```

I also extended the javadoc for `INCLUDE_STREAM_OPTIONS_PARAM`.
```java
import java.io.IOException;
import java.util.ArrayList;

import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals;
```
Where possible let's switch to using XContentHelper.stripWhitespace
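To show the idea behind whitespace-insensitive JSON assertions, here is a naive, self-contained stand-in for stripping whitespace outside string literals. It is only a sketch: Elasticsearch's `XContentHelper.stripWhitespace` parses and re-serializes the JSON rather than scanning characters:

```java
public class StripWhitespaceSketch {
    static String stripWhitespace(String json) {
        StringBuilder out = new StringBuilder();
        boolean inString = false;
        for (int i = 0; i < json.length(); i++) {
            char c = json.charAt(i);
            // Toggle string state on unescaped quotes; keep whitespace inside strings.
            if (c == '"' && (i == 0 || json.charAt(i - 1) != '\\')) inString = !inString;
            if (inString || !Character.isWhitespace(c)) out.append(c);
        }
        return out.toString();
    }

    public static void main(String[] args) {
        String expected = "{\n  \"model\": \"mistral-small\",\n  \"stream\": true\n}";
        System.out.println(stripWhitespace(expected));
        // -> {"model":"mistral-small","stream":true}
    }
}
```

This lets a test write its expected JSON in readable, indented form and still compare it byte-for-byte against compact serializer output.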
Done.
…elds and streamline constructor
…est to rename and update stream options parameter
…rtion and remove unused imports
…idate error handling
Pinging @elastic/ml-core (Team:ML)
…completion-integration # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
…completion-integration # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
Looking good, looks like CI is failing with:
Thanks for the changes! I did some testing and it looks good.
💔 Backport failed
You can use sqren/backport to manually backport by running:
💚 All backports created successfully
Questions? Please refer to the Backport tool documentation.
Change to the existing Mistral AI provider integration, allowing completion (both streaming and non-streaming) and chat_completion (streaming only) to be executed as part of the inference API.
Changes were tested against the following models:
Notes:
Examples of RQ/RS from local testing:
Create Completion Endpoint
- Success:
- Unauthorized:
- Not Found:
- Invalid Model:

Perform Non-Streaming Completion
- Success:
- Not Found:

Perform Streaming Completion
- Success:
- Not Found:

Create Chat Completion Endpoint
- Success:
- Unauthorized:
- Not Found:
- Invalid Model:

Perform Streaming Chat Completion
- Success:
- Invalid Model:
- Not Found: