
Commit bc39826

Merge branch 'dev' into jumpstart-dev
2 parents 89001ec + 90b0b0f commit bc39826

34 files changed (+1836 −96 lines)

doc/api/inference/async_inference.rst

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
Async Inference
-----------------

This module contains classes related to Amazon SageMaker Async Inference

.. automodule:: sagemaker.async_inference.async_inference_config
    :members:
    :undoc-members:
    :show-inheritance:

.. automodule:: sagemaker.async_inference.async_inference_response
    :members:
    :undoc-members:
    :show-inheritance:

.. automodule:: sagemaker.async_inference.waiter_config
    :members:
    :undoc-members:
    :show-inheritance:

doc/api/inference/predictor_async.rst

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
AsyncPredictor
--------------------

Make async predictions against SageMaker endpoints with Python objects

.. autoclass:: sagemaker.predictor_async.AsyncPredictor
    :members:
    :undoc-members:
    :show-inheritance:

doc/conf.py

Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
 import pkg_resources
 from datetime import datetime

-project = u"sagemaker"
+project = "sagemaker"
 version = pkg_resources.require(project)[0].version

 # Add any Sphinx extension module names here, as strings. They can be extensions
@@ -38,7 +38,7 @@
 source_suffix = ".rst"  # The suffix of source filenames.
 master_doc = "index"  # The master toctree document.

-copyright = u"%s, Amazon" % datetime.now().year
+copyright = "%s, Amazon" % datetime.now().year

 # The full version, including alpha/beta/rc tags.
 release = version

doc/overview.rst

Lines changed: 92 additions & 0 deletions
@@ -1121,6 +1121,98 @@ For more detailed explanations of the classes that this library provides for aut
- `API docs for HyperparameterTuner and parameter range classes <https://sagemaker.readthedocs.io/en/stable/tuner.html>`__
- `API docs for analytics classes <https://sagemaker.readthedocs.io/en/stable/analytics.html>`__

**********************************
SageMaker Asynchronous Inference
**********************************
Amazon SageMaker Asynchronous Inference is a new capability in SageMaker that queues incoming requests and processes them asynchronously.
This option is ideal for requests with large payload sizes (up to 1GB), long processing times, and near-real-time latency requirements.
You can configure Asynchronous Inference to scale the instance count to zero when there are no requests to process, thereby saving costs.
More information about SageMaker Asynchronous Inference can be found in the `AWS documentation <https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html>`__.

To deploy an asynchronous inference endpoint, you will need to create an ``AsyncInferenceConfig`` object.
If you create an ``AsyncInferenceConfig`` without specifying its arguments, the default ``S3OutputPath`` will
be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME}`` (example shown below):
.. code:: python

    from sagemaker.async_inference import AsyncInferenceConfig

    # Create an empty AsyncInferenceConfig object to use default values
    async_config = AsyncInferenceConfig()
Or you can specify configurations in ``AsyncInferenceConfig`` as you like. All of those configuration parameters
are optional, but if you don't specify the ``output_path``, Amazon SageMaker will use the default ``S3OutputPath``
mentioned above (example shown below):
.. code:: python

    # Specify S3OutputPath, MaxConcurrentInvocationsPerInstance and NotificationConfig in the async config object
    async_config = AsyncInferenceConfig(
        output_path="s3://{s3_bucket}/{bucket_prefix}/output",
        max_concurrent_invocations_per_instance=10,
        notification_config={
            "SuccessTopic": "arn:aws:sns:aws-region:account-id:topic-name",
            "ErrorTopic": "arn:aws:sns:aws-region:account-id:topic-name",
        },
    )
Then use the ``AsyncInferenceConfig`` in the estimator's ``deploy()`` method to deploy an asynchronous inference endpoint:

.. code:: python

    # Deploys the model that was generated by fit() to a SageMaker asynchronous inference endpoint
    async_predictor = estimator.deploy(async_inference_config=async_config)
After deployment is complete, it will return an ``AsyncPredictor`` object. To perform asynchronous inference, you first
need to upload data to S3 and then use the ``predict_async()`` method with the S3 URI as the input. It will return an
``AsyncInferenceResponse`` object:

.. code:: python

    # Upload data to S3 bucket then use that as input
    async_response = async_predictor.predict_async(input_path=input_s3_path)
The Amazon SageMaker SDK also enables you to serialize the data and pass the payload data directly to the
``predict_async()`` method. For this pattern of invocation, the Amazon SageMaker SDK will upload the data to an Amazon
S3 bucket under ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-inputs/``.

.. code:: python

    # Serializes data and makes a prediction request to the SageMaker asynchronous endpoint
    async_response = async_predictor.predict_async(data=data)
Then you can switch to other tasks and wait for the inference to complete. After it has completed, you can check
the result using ``AsyncInferenceResponse``:

.. code:: python

    # Switch back to check the result
    result = async_response.get_result()
Alternatively, if you would like to check for a result periodically and return it upon generation, use the
``predict()`` method:

.. code:: python

    # Use predict() to wait for the result
    response = async_predictor.predict(data=data)

    # Or use an Amazon S3 input path
    response = async_predictor.predict(input_path=input_s3_path)
Clean up the endpoint and model if needed after inference:

.. code:: python

    # Tears down the SageMaker endpoint and endpoint configuration
    async_predictor.delete_endpoint()

    # Deletes the SageMaker model
    async_predictor.delete_model()
For more details about Asynchronous Inference,
see the API docs for `Asynchronous Inference <https://sagemaker.readthedocs.io/en/stable/api/inference/async_inference.html>`__.
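The submit / do-other-work / collect flow described above can be sketched locally with a thread pool standing in for the asynchronous endpoint. This is pure illustration: ``model_invoke`` and the payload are made up, and none of these names are SageMaker APIs; only the invocation shape mirrors ``predict_async()`` followed by ``get_result()``.

```python
from concurrent.futures import ThreadPoolExecutor
import time

def model_invoke(payload):
    """Stand-in for a slow model container behind an async endpoint."""
    time.sleep(0.05)
    return {"input": payload, "label": "positive"}

executor = ThreadPoolExecutor(max_workers=2)

# Like predict_async(): submit the request and immediately get a handle back
future = executor.submit(model_invoke, {"text": "great product"})

# ... do other work here while inference runs ...

# Like get_result(): block until the output is available
result = future.result(timeout=5)
executor.shutdown()
```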
*******************************
SageMaker Serverless Inference
*******************************

src/sagemaker/analytics.py

Lines changed: 7 additions & 3 deletions
@@ -261,7 +261,11 @@ def training_job_summaries(self, force_refresh=False):
             )
             new_output = raw_result["TrainingJobSummaries"]
             output.extend(new_output)
-            logger.debug("Got %d more TrainingJobs. Total so far: %d", len(new_output), len(output))
+            logger.debug(
+                "Got %d more TrainingJobs. Total so far: %d",
+                len(new_output),
+                len(output),
+            )
             if ("NextToken" in raw_result) and (len(new_output) > 0):
                 next_args["NextToken"] = raw_result["NextToken"]
             else:
@@ -344,7 +348,7 @@ def _determine_timeinterval(self):
             a dict with the `start_time` and `end_time`.
         """
         description = self._sage_client.describe_training_job(TrainingJobName=self.name)
-        start_time = self._start_time or description[u"TrainingStartTime"]  # datetime object
+        start_time = self._start_time or description["TrainingStartTime"]  # datetime object
         # Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs.
         # This results in logs being searched in the time range in which the correct log line was
         # not present.
@@ -353,7 +357,7 @@ def _determine_timeinterval(self):
         # CW will consider end time as 2018-10-22 08:25 and will not be able to search the
         # correct log.
         end_time = self._end_time or description.get(
-            u"TrainingEndTime", datetime.datetime.utcnow()
+            "TrainingEndTime", datetime.datetime.utcnow()
         ) + datetime.timedelta(minutes=1)

         return {"start_time": start_time, "end_time": end_time}
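The reflowed ``logger.debug`` call above keeps the %-style lazy formatting that the standard ``logging`` module expects: the arguments are only interpolated into the message if the record is actually emitted. A quick self-contained illustration (the capturing handler is just for demonstration):

```python
import logging

logger = logging.getLogger("demo")
logger.setLevel(logging.DEBUG)

records = []

class Capture(logging.Handler):
    """Collects formatted log messages for inspection."""
    def emit(self, record):
        records.append(record.getMessage())

logger.addHandler(Capture())

new_output, output = [1, 2], [1, 2, 3, 4, 5]
# Arguments are passed separately; logging interpolates them lazily
logger.debug("Got %d more TrainingJobs. Total so far: %d", len(new_output), len(output))
```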
src/sagemaker/async_inference/__init__.py

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Imports the classes in this module to simplify customer imports"""

from __future__ import absolute_import

from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig  # noqa: F401
from sagemaker.async_inference.waiter_config import WaiterConfig  # noqa: F401
from sagemaker.async_inference.async_inference_response import AsyncInferenceResponse  # noqa: F401
src/sagemaker/async_inference/async_inference_config.py

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for AsyncInferenceConfig

Used for configuring async inference endpoint. Use AsyncInferenceConfig when deploying
the model to the async inference endpoints.
"""
from __future__ import print_function, absolute_import


class AsyncInferenceConfig(object):
    """Configuration object passed in when deploying models to Amazon SageMaker Endpoints.

    This object specifies configuration related to async endpoint. Use this configuration
    when trying to create async endpoint and make async inference
    """

    def __init__(
        self,
        output_path=None,
        max_concurrent_invocations_per_instance=None,
        kms_key_id=None,
        notification_config=None,
    ):
        """Initialize an AsyncInferenceConfig object for async inference configuration.

        Args:
            output_path (str): Optional. The Amazon S3 location that endpoints upload
                inference responses to. If no value is provided, Amazon SageMaker will
                use default Amazon S3 Async Inference output path. (Default: None)
            max_concurrent_invocations_per_instance (int): Optional. The maximum number of
                concurrent requests sent by the SageMaker client to the model container. If
                no value is provided, Amazon SageMaker will choose an optimal value for you.
                (Default: None)
            kms_key_id (str): Optional. The Amazon Web Services Key Management Service
                (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the
                asynchronous inference output in Amazon S3. (Default: None)
            notification_config (dict): Optional. Specifies the configuration for notifications
                of inference results for asynchronous inference. Only one notification is
                generated per invocation request (Default: None):
                * success_topic (str): Amazon SNS topic to post a notification to when inference
                  completes successfully. If no topic is provided, no notification is sent on
                  success. The key in notification_config is 'SuccessTopic'.
                * error_topic (str): Amazon SNS topic to post a notification to when inference
                  fails. If no topic is provided, no notification is sent on failure.
                  The key in notification_config is 'ErrorTopic'.
        """
        self.output_path = output_path
        self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance
        self.kms_key_id = kms_key_id
        self.notification_config = notification_config

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class."""
        request_dict = {
            "OutputConfig": {
                "S3OutputPath": self.output_path,
            },
        }

        if self.max_concurrent_invocations_per_instance:
            request_dict["ClientConfig"] = {
                "MaxConcurrentInvocationsPerInstance": self.max_concurrent_invocations_per_instance
            }

        if self.kms_key_id:
            request_dict["OutputConfig"]["KmsKeyId"] = self.kms_key_id

        if self.notification_config:
            request_dict["OutputConfig"]["NotificationConfig"] = self.notification_config

        return request_dict
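The request shape that ``_to_request_dict`` builds can be sketched standalone. The function below is a simplified re-implementation for illustration (not the sagemaker source); it reproduces the same conditional nesting of ``OutputConfig`` and ``ClientConfig`` keys:

```python
def to_request_dict(output_path=None, max_concurrent=None, kms_key_id=None,
                    notification_config=None):
    """Build an async-inference request dict (simplified sketch of _to_request_dict)."""
    request = {"OutputConfig": {"S3OutputPath": output_path}}
    if max_concurrent:
        request["ClientConfig"] = {"MaxConcurrentInvocationsPerInstance": max_concurrent}
    if kms_key_id:
        request["OutputConfig"]["KmsKeyId"] = kms_key_id
    if notification_config:
        request["OutputConfig"]["NotificationConfig"] = notification_config
    return request

# Example: only the keys that were actually provided appear in the request
cfg = to_request_dict(
    output_path="s3://my-bucket/async-output",
    max_concurrent=10,
    notification_config={"SuccessTopic": "arn:aws:sns:us-east-1:123456789012:ok"},
)
```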
src/sagemaker/async_inference/async_inference_response.py

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for AsyncInferenceResponse"""

from __future__ import print_function, absolute_import

from botocore.exceptions import ClientError
from sagemaker.s3 import parse_s3_url
from sagemaker.async_inference import WaiterConfig
from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError


class AsyncInferenceResponse(object):
    """Response from Async Inference endpoint

    This response object provides a method to check for an async inference result in the
    Amazon S3 output path specified. If result object exists in that path, get and return
    the result
    """

    def __init__(
        self,
        predictor_async,
        output_path,
    ):
        """Initialize an AsyncInferenceResponse object.

        AsyncInferenceResponse can help users to get async inference result
        from the Amazon S3 output path

        Args:
            predictor_async (sagemaker.predictor.AsyncPredictor): The ``AsyncPredictor``
                that return this response.
            output_path (str): The Amazon S3 location that endpoints upload inference responses
                to.
        """
        self.predictor_async = predictor_async
        self.output_path = output_path
        self._result = None

    def get_result(
        self,
        waiter_config=None,
    ):
        """Get async inference result in the Amazon S3 output path specified

        Args:
            waiter_config (sagemaker.async_inference.waiter_config.WaiterConfig): Configuration
                for the waiter. The pre-defined value for the delay between poll is 15 seconds
                and the default max attempts is 60
        Raises:
            ValueError: If a wrong type of object is provided as ``waiter_config``
        Returns:
            object: Inference result in the given Amazon S3 output path. If a deserializer was
                specified when creating the AsyncPredictor, the result of the deserializer is
                returned. Otherwise the response returns the sequence of bytes
                as is.
        """
        if waiter_config is not None and not isinstance(waiter_config, WaiterConfig):
            raise ValueError("waiter_config should be a WaiterConfig object")

        if self._result is None:
            if waiter_config is None:
                self._result = self._get_result_from_s3(self.output_path)
            else:
                self._result = self.predictor_async._wait_for_output(
                    self.output_path, waiter_config
                )
        return self._result

    def _get_result_from_s3(
        self,
        output_path,
    ):
        """Get inference result from the output Amazon S3 path"""
        bucket, key = parse_s3_url(output_path)
        try:
            response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
            return self.predictor_async.predictor._handle_response(response)
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise ObjectNotExistedError(
                    message="Inference could still be running",
                    output_path=output_path,
                )
            raise UnexpectedClientError(
                message=ex.response["Error"]["Message"],
            )
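The caching behavior of ``get_result`` — fetch the S3 object once, then return the stored result on subsequent calls — can be sketched with a stand-in fetcher. Illustrative only: ``fake_fetch`` plays the role of ``_get_result_from_s3`` and is not a SageMaker API:

```python
class ResultCache:
    """Mimics AsyncInferenceResponse: fetch the result once, then reuse it."""

    def __init__(self, fetch):
        self._fetch = fetch    # callable standing in for _get_result_from_s3
        self._result = None

    def get_result(self):
        if self._result is None:
            self._result = self._fetch()
        return self._result

calls = []

def fake_fetch():
    calls.append(1)
    return b"prediction-bytes"

resp = ResultCache(fake_fetch)
first = resp.get_result()
second = resp.get_result()   # served from cache; fetch is not called again
```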
