Skip to content

Commit d1c317e

Browse files
[8.19] Add Mistral AI Chat Completion support to Inference Plugin (#128538) (#128947)
* Change versions for Mistral Chat Completion version * Refactor model handling in MistralService to use instanceof for cleaner code * Update Mistral Chat Completion 8.19 Version
1 parent 76829db commit d1c317e

File tree

42 files changed

+2603
-304
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2603
-304
lines changed

docs/changelog/128538.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 128538
2+
summary: "Added Mistral Chat Completion support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ static TransportVersion def(int id) {
237237
public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY_8_19 = def(8_841_0_44);
238238
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
239239
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46);
240-
240+
public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_47);
241241
/*
242242
* STOP! READ THIS FIRST! No, really,
243243
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ public record UnifiedCompletionRequest(
7878
* {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name.
7979
*/
8080
private static final String MAX_TOKENS_PARAM = "max_tokens_field";
81+
/**
82+
* Indicates whether to include the `stream_options` field in the JSON output.
83+
* Some providers do not support this field. In such cases, this parameter should be set to "false",
84+
* and the `stream_options` field will be excluded from the output.
85+
* For providers that do support stream options, this parameter is left unset (default behavior),
86+
* which implicitly includes the `stream_options` field in the output.
87+
*/
88+
public static final String INCLUDE_STREAM_OPTIONS_PARAM = "include_stream_options";
8189

8290
/**
8391
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
@@ -91,6 +99,23 @@ public static Params withMaxTokens(String modelId, Params params) {
9199
);
92100
}
93101

102+
/**
103+
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
104+
* - Key: {@link #MODEL_FIELD}, Value: modelId
105+
* - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD}
106+
* - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
107+
*/
108+
public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
109+
return new DelegatingMapParams(
110+
Map.ofEntries(
111+
Map.entry(MODEL_ID_PARAM, modelId),
112+
Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
113+
Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
114+
),
115+
params
116+
);
117+
}
118+
94119
/**
95120
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
96121
* - Key: {@link #MODEL_FIELD}, Value: modelId

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
134134

135135
public void testGetServicesWithCompletionTaskType() throws IOException {
136136
List<Object> services = getServices(TaskType.COMPLETION);
137-
assertThat(services.size(), equalTo(13));
137+
assertThat(services.size(), equalTo(14));
138138

139139
var providers = providers(services);
140140

@@ -154,15 +154,16 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
154154
"openai",
155155
"streaming_completion_test_service",
156156
"hugging_face",
157-
"amazon_sagemaker"
157+
"amazon_sagemaker",
158+
"mistral"
158159
).toArray()
159160
)
160161
);
161162
}
162163

163164
public void testGetServicesWithChatCompletionTaskType() throws IOException {
164165
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
165-
assertThat(services.size(), equalTo(7));
166+
assertThat(services.size(), equalTo(8));
166167

167168
var providers = providers(services);
168169

@@ -176,7 +177,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
176177
"streaming_completion_test_service",
177178
"hugging_face",
178179
"amazon_sagemaker",
179-
"googlevertexai"
180+
"googlevertexai",
181+
"mistral"
180182
).toArray()
181183
)
182184
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
101101
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
102102
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
103+
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
103104
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
104105
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
105106
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
@@ -266,6 +267,13 @@ private static void addMistralNamedWriteables(List<NamedWriteableRegistry.Entry>
266267
MistralEmbeddingsServiceSettings::new
267268
)
268269
);
270+
namedWriteables.add(
271+
new NamedWriteableRegistry.Entry(
272+
ServiceSettings.class,
273+
MistralChatCompletionServiceSettings.NAME,
274+
MistralChatCompletionServiceSettings::new
275+
)
276+
);
269277

270278
// note - no task settings for Mistral embeddings...
271279
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121
* A pattern is emerging in how external providers provide error responses.
2222
*
2323
* At a minimum, these return:
24+
* <pre><code>
2425
* {
2526
* "error: {
2627
* "message": "(error message)"
2728
* }
2829
* }
29-
*
30+
* </code></pre>
3031
* Others may return additional information such as error codes specific to the service.
3132
*
3233
* This currently covers error handling for Azure AI Studio, however this pattern
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.response.streaming;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xcontent.ConstructingObjectParser;
12+
import org.elasticsearch.xcontent.ParseField;
13+
import org.elasticsearch.xcontent.XContentFactory;
14+
import org.elasticsearch.xcontent.XContentParser;
15+
import org.elasticsearch.xcontent.XContentParserConfiguration;
16+
import org.elasticsearch.xcontent.XContentType;
17+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
18+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
19+
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
20+
21+
import java.util.Objects;
22+
import java.util.Optional;
23+
24+
/**
25+
* Represents an error response from a streaming inference service.
26+
* This class extends {@link ErrorResponse} and provides additional fields
27+
* specific to streaming errors, such as code, param, and type.
28+
* An example error response for a streaming service might look like:
29+
* <pre><code>
30+
* {
31+
* "error": {
32+
* "message": "Invalid input",
33+
* "code": "400",
34+
* "param": "input",
35+
* "type": "invalid_request_error"
36+
* }
37+
* }
38+
* </code></pre>
39+
* TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication.
40+
*/
41+
public class StreamingErrorResponse extends ErrorResponse {
42+
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
43+
"streaming_error",
44+
true,
45+
args -> Optional.ofNullable((StreamingErrorResponse) args[0])
46+
);
47+
private static final ConstructingObjectParser<StreamingErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
48+
"streaming_error",
49+
true,
50+
args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
51+
);
52+
53+
static {
54+
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
55+
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code"));
56+
ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param"));
57+
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type"));
58+
59+
ERROR_PARSER.declareObjectOrNull(
60+
ConstructingObjectParser.optionalConstructorArg(),
61+
ERROR_BODY_PARSER,
62+
null,
63+
new ParseField("error")
64+
);
65+
}
66+
67+
/**
68+
* Standard error response parser. This can be overridden for those subclasses that
69+
* have a different error response structure.
70+
* @param response The error response as an HttpResult
71+
*/
72+
public static ErrorResponse fromResponse(HttpResult response) {
73+
try (
74+
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
75+
.createParser(XContentParserConfiguration.EMPTY, response.body())
76+
) {
77+
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
78+
} catch (Exception e) {
79+
// swallow the error
80+
}
81+
82+
return ErrorResponse.UNDEFINED_ERROR;
83+
}
84+
85+
/**
86+
* Standard error response parser. This can be overridden for those subclasses that
87+
* have a different error response structure.
88+
* @param response The error response as a string
89+
*/
90+
public static ErrorResponse fromString(String response) {
91+
try (
92+
XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)
93+
) {
94+
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
95+
} catch (Exception e) {
96+
// swallow the error
97+
}
98+
99+
return ErrorResponse.UNDEFINED_ERROR;
100+
}
101+
102+
@Nullable
103+
private final String code;
104+
@Nullable
105+
private final String param;
106+
private final String type;
107+
108+
StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
109+
super(errorMessage);
110+
this.code = code;
111+
this.param = param;
112+
this.type = Objects.requireNonNull(type);
113+
}
114+
115+
@Nullable
116+
public String code() {
117+
return code;
118+
}
119+
120+
@Nullable
121+
public String param() {
122+
return param;
123+
}
124+
125+
public String type() {
126+
return type;
127+
}
128+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
import java.io.IOException;
1616
import java.util.Objects;
1717

18+
import static org.elasticsearch.inference.UnifiedCompletionRequest.INCLUDE_STREAM_OPTIONS_PARAM;
19+
20+
/**
21+
* Represents a unified chat completion request entity.
22+
* This class is used to convert the unified chat input into a format that can be serialized to XContent.
23+
*/
1824
public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
1925

2026
public static final String STREAM_FIELD = "stream";
@@ -42,7 +48,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4248
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
4349

4450
builder.field(STREAM_FIELD, stream);
45-
if (stream) {
51+
// If request is streamed and skip stream options parameter is not true, include stream options in the request.
52+
if (stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true)) {
4653
builder.startObject(STREAM_OPTIONS_FIELD);
4754
builder.field(INCLUDE_USAGE_FIELD, true);
4855
builder.endObject();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.mistral;
9+
10+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
11+
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse;
12+
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
13+
14+
/**
15+
* Handles non-streaming completion responses for Mistral models, extending the OpenAI completion response handler.
16+
* This class is specifically designed to handle Mistral's error response format.
17+
*/
18+
public class MistralCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
19+
20+
/**
21+
* Constructs a MistralCompletionResponseHandler with the specified request type and response parser.
22+
*
23+
* @param requestType The type of request being handled (e.g., "mistral completions").
24+
* @param parseFunction The function to parse the response.
25+
*/
26+
public MistralCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
27+
super(requestType, parseFunction, MistralErrorResponse::fromResponse);
28+
}
29+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralConstants.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
public class MistralConstants {
1111
public static final String API_EMBEDDINGS_PATH = "https://api.mistral.ai/v1/embeddings";
12+
public static final String API_COMPLETIONS_PATH = "https://api.mistral.ai/v1/chat/completions";
1213

1314
// note - there is no bounds information available from Mistral,
1415
// so we'll use a sane default here which is the same as Cohere's
@@ -18,4 +19,8 @@ public class MistralConstants {
1819
public static final String MODEL_FIELD = "model";
1920
public static final String INPUT_FIELD = "input";
2021
public static final String ENCODING_FORMAT_FIELD = "encoding_format";
22+
public static final String MAX_TOKENS_FIELD = "max_tokens";
23+
public static final String DETAIL_FIELD = "detail";
24+
public static final String MSG_FIELD = "msg";
25+
public static final String MESSAGE_FIELD = "message";
2126
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
2323
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
2424
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
25-
import org.elasticsearch.xpack.inference.services.mistral.request.MistralEmbeddingsRequest;
25+
import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest;
2626
import org.elasticsearch.xpack.inference.services.mistral.response.MistralEmbeddingsResponseEntity;
2727

2828
import java.util.List;

0 commit comments

Comments
 (0)