Skip to content

Commit 2830768

Browse files
authored (author name missing from page extraction)
[ML] Integrate OpenAi Chat Completion in SageMaker (#127767)
SageMaker now supports Completion and Chat Completion using the OpenAI interfaces. Additionally:
- Fixed a bug related to timeouts being nullable; they now default to a 30s timeout.
- Exposed existing OpenAI request/response parsing logic for reuse.
1 parent 13f3864 commit 2830768

File tree

21 files changed

+669
-150
lines changed

21 files changed

+669
-150
lines changed

docs/changelog/127767.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127767
2+
summary: Integrate `OpenAi` Chat Completion in `SageMaker`
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ static TransportVersion def(int id) {
180180
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
181181
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
182182
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
183+
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
183184
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
184185
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
185186
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -264,6 +265,7 @@ static TransportVersion def(int id) {
264265
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
265266
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
266267
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00);
268+
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_082_0_00);
267269
/*
268270
* STOP! READ THIS FIRST! No, really,
269271
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
124124

125125
public void testGetServicesWithCompletionTaskType() throws IOException {
126126
List<Object> services = getServices(TaskType.COMPLETION);
127-
assertThat(services.size(), equalTo(11));
127+
assertThat(services.size(), equalTo(12));
128128

129129
var providers = providers(services);
130130

@@ -142,21 +142,24 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
142142
"googleaistudio",
143143
"openai",
144144
"streaming_completion_test_service",
145-
"hugging_face"
145+
"hugging_face",
146+
"sagemaker"
146147
).toArray()
147148
)
148149
);
149150
}
150151

151152
public void testGetServicesWithChatCompletionTaskType() throws IOException {
152153
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
153-
assertThat(services.size(), equalTo(5));
154+
assertThat(services.size(), equalTo(6));
154155

155156
var providers = providers(services);
156157

157158
assertThat(
158159
providers,
159-
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())
160+
containsInAnyOrder(
161+
List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "sagemaker").toArray()
162+
)
160163
);
161164
}
162165

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212
import org.elasticsearch.xcontent.XContentParserConfiguration;
1313
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
1414

15-
import java.io.IOException;
1615
import java.util.ArrayDeque;
1716
import java.util.Deque;
18-
import java.util.Iterator;
1917
import java.util.concurrent.Flow;
2018
import java.util.concurrent.atomic.AtomicBoolean;
2119
import java.util.concurrent.atomic.AtomicLong;
20+
import java.util.stream.Stream;
2221

2322
/**
2423
* Processor that delegates the {@link java.util.concurrent.Flow.Subscription} to the upstream {@link java.util.concurrent.Flow.Publisher}
@@ -34,19 +33,13 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
3433
public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
3534
Deque<ServerSentEvent> item,
3635
ParseChunkFunction<ParsedChunk> parseFunction,
37-
XContentParserConfiguration parserConfig,
38-
Logger logger
39-
) throws Exception {
36+
XContentParserConfiguration parserConfig
37+
) {
4038
var results = new ArrayDeque<ParsedChunk>(item.size());
4139
for (ServerSentEvent event : item) {
4240
if (event.hasData()) {
43-
try {
44-
var delta = parseFunction.apply(parserConfig, event);
45-
delta.forEachRemaining(results::offer);
46-
} catch (Exception e) {
47-
logger.warn("Failed to parse event from inference provider: {}", event);
48-
throw e;
49-
}
41+
var delta = parseFunction.apply(parserConfig, event);
42+
delta.forEach(results::offer);
5043
}
5144
}
5245

@@ -55,7 +48,7 @@ public static <ParsedChunk> Deque<ParsedChunk> parseEvent(
5548

5649
@FunctionalInterface
5750
public interface ParseChunkFunction<ParsedChunk> {
58-
Iterator<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException;
51+
Stream<ParsedChunk> apply(XContentParserConfiguration parserConfig, ServerSentEvent event);
5952
}
6053

6154
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
4545
private final boolean stream;
4646

4747
public UnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) {
48-
Objects.requireNonNull(unifiedChatInput);
48+
this(Objects.requireNonNull(unifiedChatInput).getRequest(), Objects.requireNonNull(unifiedChatInput).stream());
49+
}
4950

50-
this.unifiedRequest = unifiedChatInput.getRequest();
51-
this.stream = unifiedChatInput.stream();
51+
public UnifiedChatCompletionRequestEntity(UnifiedCompletionRequest unifiedRequest, boolean stream) {
52+
this.unifiedRequest = Objects.requireNonNull(unifiedRequest);
53+
this.stream = stream;
5254
}
5355

5456
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiStreamingProcessor.java

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.ElasticsearchStatusException;
1213
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
1314
import org.elasticsearch.inference.InferenceServiceResults;
15+
import org.elasticsearch.rest.RestStatus;
1416
import org.elasticsearch.xcontent.XContentFactory;
1517
import org.elasticsearch.xcontent.XContentParser;
1618
import org.elasticsearch.xcontent.XContentParserConfiguration;
@@ -20,11 +22,10 @@
2022
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
2123

2224
import java.io.IOException;
23-
import java.util.Collections;
2425
import java.util.Deque;
25-
import java.util.Iterator;
2626
import java.util.Objects;
2727
import java.util.function.Predicate;
28+
import java.util.stream.Stream;
2829

2930
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
3031
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
@@ -113,7 +114,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSe
113114
@Override
114115
protected void next(Deque<ServerSentEvent> item) throws Exception {
115116
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
116-
var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log);
117+
var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig);
117118

118119
if (results.isEmpty()) {
119120
upstream().request(1);
@@ -122,10 +123,9 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
122123
}
123124
}
124125

125-
private static Iterator<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event)
126-
throws IOException {
126+
public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
127127
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
128-
return Collections.emptyIterator();
128+
return Stream.empty();
129129
}
130130

131131
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@@ -167,11 +167,14 @@ private static Iterator<StreamingChatCompletionResults.Result> parse(XContentPar
167167

168168
consumeUntilObjectEnd(parser); // end choices
169169
return ""; // stopped
170-
}).stream()
171-
.filter(Objects::nonNull)
172-
.filter(Predicate.not(String::isEmpty))
173-
.map(StreamingChatCompletionResults.Result::new)
174-
.iterator();
170+
}).stream().filter(Objects::nonNull).filter(Predicate.not(String::isEmpty)).map(StreamingChatCompletionResults.Result::new);
171+
} catch (IOException e) {
172+
throw new ElasticsearchStatusException(
173+
"Failed to parse event from inference provider: {}",
174+
RestStatus.INTERNAL_SERVER_ERROR,
175+
e,
176+
event
177+
);
175178
}
176179
}
177180
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ public OpenAiUnifiedChatCompletionResponseHandler(
5050
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
5151
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
5252
var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e));
53-
5453
flow.subscribe(serverSentEventProcessor);
5554
serverSentEventProcessor.subscribe(openAiProcessor);
5655
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
@@ -81,14 +80,18 @@ protected static String createErrorType(ErrorResponse errorResponse) {
8180
}
8281

8382
protected Exception buildMidStreamError(Request request, String message, Exception e) {
83+
return buildMidStreamError(request.getInferenceEntityId(), message, e);
84+
}
85+
86+
public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) {
8487
var errorResponse = OpenAiErrorResponse.fromString(message);
8588
if (errorResponse instanceof OpenAiErrorResponse oer) {
8689
return new UnifiedChatCompletionException(
8790
RestStatus.INTERNAL_SERVER_ERROR,
8891
format(
8992
"%s for request from inference entity id [%s]. Error message: [%s]",
9093
SERVER_ERROR_OBJECT,
91-
request.getInferenceEntityId(),
94+
inferenceEntityId,
9295
errorResponse.getErrorMessage()
9396
),
9497
oer.type(),
@@ -100,7 +103,7 @@ protected Exception buildMidStreamError(Request request, String message, Excepti
100103
} else {
101104
return new UnifiedChatCompletionException(
102105
RestStatus.INTERNAL_SERVER_ERROR,
103-
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
106+
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId),
104107
createErrorType(errorResponse),
105108
"stream_error"
106109
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,10 @@
2222

2323
import java.io.IOException;
2424
import java.util.ArrayDeque;
25-
import java.util.Collections;
2625
import java.util.Deque;
27-
import java.util.Iterator;
2826
import java.util.List;
2927
import java.util.function.BiFunction;
28+
import java.util.stream.Stream;
3029

3130
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
3231
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
@@ -75,7 +74,7 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
7574
} else if (event.hasData()) {
7675
try {
7776
var delta = parse(parserConfig, event);
78-
delta.forEachRemaining(results::offer);
77+
delta.forEach(results::offer);
7978
} catch (Exception e) {
8079
logger.warn("Failed to parse event from inference provider: {}", event);
8180
throw errorParser.apply(event.data(), e);
@@ -90,12 +89,12 @@ protected void next(Deque<ServerSentEvent> item) throws Exception {
9089
}
9190
}
9291

93-
private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
92+
public static Stream<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
9493
XContentParserConfiguration parserConfig,
9594
ServerSentEvent event
9695
) throws IOException {
9796
if (DONE_MESSAGE.equalsIgnoreCase(event.data())) {
98-
return Collections.emptyIterator();
97+
return Stream.empty();
9998
}
10099

101100
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.data())) {
@@ -106,7 +105,7 @@ private static Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChun
106105

107106
StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser);
108107

109-
return Collections.singleton(chunk).iterator();
108+
return Stream.of(chunk);
110109
}
111110
}
112111

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/response/OpenAiChatCompletionResponseEntity.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ public class OpenAiChatCompletionResponseEntity {
6767
*/
6868

6969
public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
70-
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
70+
return fromResponse(response.body());
71+
}
72+
73+
public static ChatCompletionResults fromResponse(byte[] response) throws IOException {
74+
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response)) {
7175
return CompletionResult.PARSER.apply(p, null).toChatCompletionResults();
7276
}
7377
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
public class SageMakerService implements InferenceService {
4848
public static final String NAME = "sagemaker";
4949
private static final int DEFAULT_BATCH_SIZE = 256;
50+
private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
5051
private final SageMakerModelBuilder modelBuilder;
5152
private final SageMakerClient client;
5253
private final SageMakerSchemas schemas;
@@ -128,7 +129,7 @@ public void infer(
128129
boolean stream,
129130
Map<String, Object> taskSettings,
130131
InputType inputType,
131-
TimeValue timeout,
132+
@Nullable TimeValue timeout,
132133
ActionListener<InferenceServiceResults> listener
133134
) {
134135
if (model instanceof SageMakerModel == false) {
@@ -148,7 +149,7 @@ public void infer(
148149
client.invokeStream(
149150
regionAndSecrets,
150151
request,
151-
timeout,
152+
timeout != null ? timeout : DEFAULT_TIMEOUT,
152153
ActionListener.wrap(
153154
response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
154155
e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -160,7 +161,7 @@ public void infer(
160161
client.invoke(
161162
regionAndSecrets,
162163
request,
163-
timeout,
164+
timeout != null ? timeout : DEFAULT_TIMEOUT,
164165
ActionListener.wrap(
165166
response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
166167
e -> listener.onFailure(schema.error(sageMakerModel, e))
@@ -201,7 +202,7 @@ private static ElasticsearchStatusException internalFailure(Model model, Excepti
201202
public void unifiedCompletionInfer(
202203
Model model,
203204
UnifiedCompletionRequest request,
204-
TimeValue timeout,
205+
@Nullable TimeValue timeout,
205206
ActionListener<InferenceServiceResults> listener
206207
) {
207208
if (model instanceof SageMakerModel == false) {
@@ -217,7 +218,7 @@ public void unifiedCompletionInfer(
217218
client.invokeStream(
218219
regionAndSecrets,
219220
sagemakerRequest,
220-
timeout,
221+
timeout != null ? timeout : DEFAULT_TIMEOUT,
221222
ActionListener.wrap(
222223
response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
223224
e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
@@ -235,7 +236,7 @@ public void chunkedInfer(
235236
List<ChunkInferenceInput> input,
236237
Map<String, Object> taskSettings,
237238
InputType inputType,
238-
TimeValue timeout,
239+
@Nullable TimeValue timeout,
239240
ActionListener<List<ChunkedInference>> listener
240241
) {
241242
if (model instanceof SageMakerModel == false) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
import org.elasticsearch.inference.TaskType;
1313
import org.elasticsearch.rest.RestStatus;
1414
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
15+
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload;
1516
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
1617

1718
import java.util.Arrays;
1819
import java.util.EnumSet;
20+
import java.util.HashMap;
1921
import java.util.List;
2022
import java.util.Map;
2123
import java.util.Set;
@@ -39,7 +41,7 @@ public class SageMakerSchemas {
3941
/*
4042
* Add new model API to the register call.
4143
*/
42-
schemas = register(new OpenAiTextEmbeddingPayload());
44+
schemas = register(new OpenAiTextEmbeddingPayload(), new OpenAiCompletionPayload());
4345

4446
streamSchemas = schemas.entrySet()
4547
.stream()
@@ -88,7 +90,16 @@ public static List<NamedWriteableRegistry.Entry> namedWriteables() {
8890
)
8991
),
9092
schemas.values().stream().flatMap(SageMakerSchema::namedWriteables)
91-
).toList();
93+
)
94+
// Dedupe based on Entry name, we allow Payloads to declare the same Entry but the Registry does not handle duplicates
95+
.collect(
96+
() -> new HashMap<String, NamedWriteableRegistry.Entry>(),
97+
(map, entry) -> map.putIfAbsent(entry.name, entry),
98+
Map::putAll
99+
)
100+
.values()
101+
.stream()
102+
.toList();
92103
}
93104

94105
public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException {

0 commit comments

Comments (0)