Skip to content

Commit cce4d12

Browse files
committed
feature: add support for async inference
1 parent 554d735 commit cce4d12

22 files changed

+1483
-2
lines changed

doc/api/inference/async_inference.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
Async Inference
2+
-----------------
3+
4+
This module contains classes related to Amazon SageMaker Async Inference
5+
6+
.. automodule:: sagemaker.async_inference.async_inference_config
7+
:members:
8+
:undoc-members:
9+
:show-inheritance:
10+
11+
.. automodule:: sagemaker.async_inference.async_inference_response
12+
:members:
13+
:undoc-members:
14+
:show-inheritance:
15+
16+
.. automodule:: sagemaker.async_inference.waiter_config
17+
:members:
18+
:undoc-members:
19+
:show-inheritance:

doc/api/inference/predictor_async.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
AsyncPredictor
2+
--------------------
3+
4+
Make async predictions against SageMaker endpoints with Python objects
5+
6+
.. autoclass:: sagemaker.predictor_async.AsyncPredictor
7+
:members:
8+
:undoc-members:
9+
:show-inheritance:
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Imports the classes in this module to simplify customer imports"""
14+
15+
from __future__ import absolute_import
16+
17+
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig # noqa: F401
18+
from sagemaker.async_inference.waiter_config import WaiterConfig # noqa: F401
19+
from sagemaker.async_inference.async_inference_response import AsyncInferenceResponse # noqa: F401
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A class for AsyncInferenceConfig
14+
15+
Codes are used for configuring async inference endpoint. Use it when deploying
16+
the model to the endpoints.
17+
"""
18+
from __future__ import print_function, absolute_import
19+
20+
21+
class AsyncInferenceConfig(object):
    """Configuration object passed in when deploying models to Amazon SageMaker Endpoints.

    This object specifies configuration related to async endpoint. Use this configuration
    when trying to create async endpoint and make async inference
    """

    def __init__(
        self,
        output_path=None,
        max_concurrent_invocations_per_instance=None,
        kms_key_id=None,
        notification_config=None,
    ):
        """Initialize an AsyncInferenceConfig object for async inference related configuration.

        Args:
            output_path (str): Optional. The Amazon S3 location that endpoints upload
                inference responses to. If no value is provided, Amazon SageMaker will
                use default Amazon S3 Async Inference output path. (Default: None)
            max_concurrent_invocations_per_instance (int): Optional. The maximum number of
                concurrent requests sent by the SageMaker client to the model container. If
                no value is provided, Amazon SageMaker will choose an optimal value for you.
                (Default: None)
            kms_key_id (str): Optional. The Amazon Web Services Key Management Service
                (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the
                asynchronous inference output in Amazon S3. (Default: None)
            notification_config (dict): Optional. Specifies the configuration for notifications
                of inference results for asynchronous inference (Default: None):
                * success_topic (str): Amazon SNS topic to post a notification to when inference
                completes successfully. If no topic is provided, no notification is sent on success.
                The key in notification_config is 'SuccessTopic'.
                * error_topic (str): Amazon SNS topic to post a notification to when inference
                fails. If no topic is provided, no notification is sent on failure.
                The key in notification_config is 'ErrorTopic'.
        """
        self.output_path = output_path
        self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance
        self.kms_key_id = kms_key_id
        self.notification_config = notification_config

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class."""
        # Assemble OutputConfig first; optional fields are only added when set so the
        # request mirrors exactly what the caller configured.
        output_config = {"S3OutputPath": self.output_path}
        request_dict = {"OutputConfig": output_config}

        if self.max_concurrent_invocations_per_instance:
            request_dict["ClientConfig"] = {
                "MaxConcurrentInvocationsPerInstance": self.max_concurrent_invocations_per_instance
            }

        if self.kms_key_id:
            output_config["KmsKeyId"] = self.kms_key_id

        if self.notification_config:
            output_config["NotificationConfig"] = self.notification_config

        return request_dict
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A class for AsyncInferenceResponse"""
14+
15+
from __future__ import print_function, absolute_import
16+
17+
from botocore.exceptions import ClientError
18+
from sagemaker.s3 import parse_s3_url
19+
from sagemaker.async_inference import WaiterConfig
20+
from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError
21+
22+
23+
24+
class AsyncInferenceResponse(object):
    """Response from Async Inference endpoint

    This response object provides a method to check the async Amazon S3
    output path. If result object exists in that path, decode and return
    the result
    """

    def __init__(
        self,
        predictor_async,
        output_path,
    ):
        """Initialize an AsyncInferenceResponse object.

        AsyncInferenceResponse can help users to get async inference result
        from the Amazon S3 output path

        Args:
            predictor_async (sagemaker.predictor.AsyncPredictor): The ``AsyncPredictor``
                that return this response.
            output_path (str): The Amazon S3 location that endpoints upload inference responses
                to.
        """
        self.predictor_async = predictor_async
        self.output_path = output_path
        # Cache of the decoded result; populated on the first successful fetch.
        self._result = None

    def get_result(
        self,
        waiter_config=None,
    ):
        """Get result from the async Amazon S3 output path

        Args:
            waiter_config (sagemaker.async_inference.waiter_config.WaiterConfig): Configuration
                for the waiter. The pre-defined value for the delay between poll is 15 seconds
                and the default max attempts is 60
        Raises:
            ValueError: If a wrong type of object is provided as ``waiter_config``
        Returns:
            object: Inference result in the given Amazon S3 output path. If a deserializer was
                specified when creating the AsyncPredictor, the result of the deserializer is
                returned. Otherwise the response returns the sequence of bytes
                as is.
        """
        if waiter_config is not None and not isinstance(waiter_config, WaiterConfig):
            raise ValueError("waiter_config should be a WaiterConfig object")

        # Return the cached result if we have already fetched it once.
        if self._result is not None:
            return self._result

        if waiter_config is None:
            # Single, non-blocking check of the output location.
            self._result = self._get_result_from_s3(self.output_path)
        else:
            # Poll the output location according to the waiter configuration.
            self._result = self.predictor_async._wait_for_output(self.output_path, waiter_config)
        return self._result

    def _get_result_from_s3(
        self,
        output_path,
    ):
        """Get inference result from the output Amazon S3 path"""
        bucket, key = parse_s3_url(output_path)
        try:
            s3_object = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
            return self.predictor_async.predictor._handle_response(s3_object)
        except ClientError as ex:
            # A missing key most likely means the async inference has not finished yet.
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise ObjectNotExistedError(
                    message="Inference could still be running",
                    output_path=output_path,
                )
            # Any other client error is unexpected; surface its message.
            raise UnexpectedClientError(
                message=ex.response["Error"]["Message"],
            )
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A class for WaiterConfig used in async inference
14+
15+
Use it when using async inference and wait for the result.
16+
"""
17+
18+
from __future__ import absolute_import
19+
20+
21+
class WaiterConfig(object):
    """Configuration object passed in when using async inference and wait for the result."""

    def __init__(
        self,
        max_attempts=60,
        delay=15,
    ):
        """Initialize a WaiterConfig object that provides parameters to control waiting behavior.

        Args:
            max_attempts (int): The maximum number of attempts to be made. (Default: 60)
            delay (int): The amount of time in seconds to wait between attempts. (Default: 15)
        """
        self.max_attempts = max_attempts
        self.delay = delay

    def _to_waiter_dict(self):
        """Generates a dictionary using the parameters provided to the class."""
        # Keys follow the botocore waiter configuration naming convention.
        return {
            "Delay": self.delay,
            "MaxAttempts": self.max_attempts,
        }

src/sagemaker/estimator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,7 @@ def deploy(
864864
kms_key=None,
865865
data_capture_config=None,
866866
tags=None,
867+
async_inference_config=None,
867868
**kwargs,
868869
):
869870
"""Deploy the trained model to an Amazon SageMaker endpoint.
@@ -910,6 +911,11 @@ def deploy(
910911
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
911912
configuration related to Endpoint data capture for use with
912913
Amazon SageMaker Model Monitoring. Default: None.
914+
async_inference_config (sagemaker.async_inference.AsyncInferenceConfig): Specifies
915+
configuration related to async endpoint. Use this configuration when trying
916+
to create async endpoint and make async inference. If empty config object
917+
passed through, we will use default config to deploy async endpoint
918+
(default: None)
913919
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
914920
endpoint. Example:
915921
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
@@ -959,6 +965,7 @@ def deploy(
959965
wait=wait,
960966
kms_key=kms_key,
961967
data_capture_config=data_capture_config,
968+
async_inference_config=async_inference_config,
962969
)
963970

964971
def register(

src/sagemaker/exceptions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,43 @@ def __init__(self, message, allowed_statuses, actual_status):
2121
self.allowed_statuses = allowed_statuses
2222
self.actual_status = actual_status
2323
super(UnexpectedStatusException, self).__init__(message)
24+
25+
26+
class AsyncInferenceError(Exception):
    """The base exception class for Async Inference exceptions.

    Subclasses declare a ``fmt`` message template; the keyword arguments passed
    to the constructor fill that template and become the exception message.
    """

    fmt = "An unspecified error occurred"

    def __init__(self, **kwargs):
        # Render the subclass-specific template and store the fields for inspection.
        Exception.__init__(self, self.fmt.format(**kwargs))
        self.kwargs = kwargs


class ObjectNotExistedError(AsyncInferenceError):
    """Raised when Amazon S3 object not exist in the given path"""

    fmt = "Object not exist at {output_path}. {message}"

    def __init__(self, message, output_path):
        super(ObjectNotExistedError, self).__init__(message=message, output_path=output_path)


class PollingTimeoutError(AsyncInferenceError):
    """Raised when wait longer than expected and no result object in Amazon S3 bucket yet"""

    fmt = "No result at {output_path} after polling for {seconds} seconds. {message}"

    def __init__(self, message, output_path, seconds):
        super(PollingTimeoutError, self).__init__(
            message=message, output_path=output_path, seconds=seconds
        )


class UnexpectedClientError(AsyncInferenceError):
    """Raised when ClientError's error code is not expected"""

    fmt = "Encountered unexpected client error: {message}"

    def __init__(self, message):
        super(UnexpectedClientError, self).__init__(message=message)

0 commit comments

Comments
 (0)