feature: Adds support for Serverless inference #2831
@@ -32,6 +32,7 @@
 from sagemaker.inputs import CompilationInput
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.predictor import PredictorBase
+from sagemaker.serverless import ServerlessInferenceConfig
 from sagemaker.transformer import Transformer

 LOGGER = logging.getLogger("sagemaker")
@@ -209,7 +210,7 @@ def register(
             model_package_arn=model_package.get("ModelPackageArn"),
         )

-    def _init_sagemaker_session_if_does_not_exist(self, instance_type):
+    def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
         """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.

         The type of session object is determined by the instance type.

Reviewer: what does this do?
Author: Since serverless will need `instance_type` to be None, this makes the parameter optional.
@@ -688,8 +689,8 @@ def compile(

     def deploy(
         self,
-        initial_instance_count,
-        instance_type,
+        initial_instance_count=None,
+        instance_type=None,
         serializer=None,
         deserializer=None,
         accelerator_type=None,

@@ -698,6 +699,7 @@ def deploy(
         kms_key=None,
         wait=True,
         data_capture_config=None,
+        serverless_inference_config=None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -715,9 +717,13 @@ def deploy(

         Args:
             initial_instance_count (int): The initial number of instances to run
-                in the ``Endpoint`` created from this ``Model``.
+                in the ``Endpoint`` created from this ``Model``. If not using
+                serverless inference, then it needs to be a number larger than or
+                equal to 1 (default: None)
             instance_type (str): The EC2 instance type to deploy this Model to.
-                For example, 'ml.p2.xlarge', or 'local' for local mode.
+                For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
+                serverless inference, then it is required to deploy a model.
+                (default: None)
             serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
                 serializer object, used to encode data for an inference endpoint
                 (default: None). If ``serializer`` is not None, then
@@ -746,7 +752,17 @@ def deploy(

             data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
                 configuration related to Endpoint data capture for use with
                 Amazon SageMaker Model Monitoring. Default: None.
+            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
+                Specifies configuration related to the serverless endpoint. Use this
+                configuration when creating a serverless endpoint and making serverless
+                inference requests. If an empty object is passed through, pre-defined
+                values in the ``ServerlessInferenceConfig`` class are used to deploy the
+                serverless endpoint (default: None)
         Raises:
             ValueError: If the argument-combination check fails in any of these circumstances:
                 - If no role is specified or
                 - If serverless inference config is not specified and instance type and instance
                   count are also not specified or
                 - If a wrong type of object is provided as serverless inference config
         Returns:
             callable[string, sagemaker.session.Session] or None: Invocation of
             ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
@@ -758,27 +774,47 @@ def deploy(

         if self.role is None:
             raise ValueError("Role can not be null for deploying a model")

-        if instance_type.startswith("ml.inf") and not self._is_compiled_model:
+        is_serverless = serverless_inference_config is not None
+        if not is_serverless and not (instance_type and initial_instance_count):
+            raise ValueError(
+                "Must specify instance type and instance count unless using serverless inference"
+            )
+
+        if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig):
+            raise ValueError(
+                "serverless_inference_config needs to be a ServerlessInferenceConfig object"
+            )
+
+        if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model:
             LOGGER.warning(
                 "Your model is not compiled. Please compile your model before using Inferentia."
             )

-        compiled_model_suffix = "-".join(instance_type.split(".")[:-1])
-        if self._is_compiled_model:
+        compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
+        if self._is_compiled_model and not is_serverless:
             self._ensure_base_name_if_needed(self.image_uri)
             if self._base_name is not None:
                 self._base_name = "-".join((self._base_name, compiled_model_suffix))

         self._create_sagemaker_model(instance_type, accelerator_type, tags)

+        serverless_inference_config_dict = (
+            serverless_inference_config._to_request_dict() if is_serverless else None
+        )
         production_variant = sagemaker.production_variant(
-            self.name, instance_type, initial_instance_count, accelerator_type=accelerator_type
+            self.name,
+            instance_type,
+            initial_instance_count,
+            accelerator_type=accelerator_type,
+            serverless_inference_config=serverless_inference_config_dict,
         )
         if endpoint_name:
             self.endpoint_name = endpoint_name
         else:
             base_endpoint_name = self._base_name or utils.base_from_name(self.name)
-            if self._is_compiled_model and not base_endpoint_name.endswith(compiled_model_suffix):
-                base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
+            if self._is_compiled_model and not is_serverless:
+                if not base_endpoint_name.endswith(compiled_model_suffix):
+                    base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
             self.endpoint_name = utils.name_from_base(base_endpoint_name)

         data_capture_config_dict = None

Reviewer (on the added `instance_type and` guard): what is this for?
Author: Same as above, `instance_type` can now be None for serverless inference.

Reviewer (on passing `accelerator_type`): We are not supporting accelerators for serverless, will the below code impact in any way?
Author: This will be validated on the low-level botocore/boto3 libs, so it should be fine.
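Taken together, the changes above let `deploy` be called either with instance arguments or with only a serverless config. A hedged usage sketch follows; the image URI, model data path, and role are placeholders, not values from this PR:

```python
from sagemaker.model import Model
from sagemaker.serverless import ServerlessInferenceConfig

# Placeholder model; substitute a real image, artifact, and role.
model = Model(
    image_uri="<ecr-image-uri>",
    model_data="s3://<bucket>/model.tar.gz",
    role="<execution-role-arn>",
)

# Instance-based deploy: both instance arguments are required.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.c5.xlarge",
)

# Serverless deploy: no instance arguments; memory and concurrency
# come from the config (or its SDK-side defaults).
serverless_predictor = model.deploy(
    serverless_inference_config=ServerlessInferenceConfig(
        memory_size_in_mb=2048,
        max_concurrency=5,
    ),
)
```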
@@ -0,0 +1,54 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains code related to the ServerlessInferenceConfig class.

Code here is used for configuring serverless inference endpoints. Use it when
deploying the model to the endpoints.
"""
from __future__ import print_function, absolute_import
class ServerlessInferenceConfig(object):
    """Configuration object passed in when deploying models to Amazon SageMaker Endpoints.

    This object specifies configuration related to serverless endpoints. Use this
    configuration when creating a serverless endpoint and making serverless inference
    requests.
    """

    def __init__(
        self,
        memory_size_in_mb=2048,
        max_concurrency=5,
    ):
        """Initialize a ServerlessInferenceConfig object for serverless inference configuration.

        Args:
            memory_size_in_mb (int): Optional. The memory size of your serverless endpoint.
                Valid values are in 1 GB increments: 1024 MB, 2048 MB, 3072 MB, 4096 MB,
                5120 MB, or 6144 MB. If no value is provided, Amazon SageMaker will choose
                the default value for you. (Default: 2048)
            max_concurrency (int): Optional. The maximum number of concurrent invocations
                your serverless endpoint can process. If no value is provided, Amazon
                SageMaker will choose the default value for you. (Default: 5)
        """
        self.memory_size_in_mb = memory_size_in_mb
        self.max_concurrency = max_concurrency

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class."""
        request_dict = {
            "MemorySizeInMB": self.memory_size_in_mb,
            "MaxConcurrency": self.max_concurrency,
        }

        return request_dict

Reviewer (on the `memory_size_in_mb` default): Default is 1 GB, please check with Michael once.
Author: There is no default for memory and max concurrency shown in the boto3 API. I can double check with Michael. The reason why we set a default here in the SageMaker SDK is to simplify the customer use case: they don't need to specify this config and we'll have a proper value for them. On the other hand, if they want to set those, they can choose the config they prefer.
Another reviewer: Yep, the default is controlled by the SageMaker SDK here - there's no inherent default in the CreateEndpointConfig API itself.

Reviewer (on `max_concurrency`): Is there a default max concurrency with boto3 today? Good to double check with Michael.
Author: Not with boto3.
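To make the mapping concrete: `_to_request_dict` emits the `ServerlessConfig` structure that `production_variant` forwards to the `CreateEndpointConfig` API. A quick demonstration using only the code shown above:

```python
from sagemaker.serverless import ServerlessInferenceConfig

config = ServerlessInferenceConfig()  # SDK-side defaults
print(config._to_request_dict())
# {'MemorySizeInMB': 2048, 'MaxConcurrency': 5}

config = ServerlessInferenceConfig(memory_size_in_mb=4096, max_concurrency=10)
print(config._to_request_dict())
# {'MemorySizeInMB': 4096, 'MaxConcurrency': 10}
```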
Reviewer: If I understand correctly - if customers make it None, then it means serverless, else instance based? Right?

Author: Sorry for the confusion, I think the more reasonable logic is:

1. We use `serverless_inference_config` to infer whether users are using serverless inference. If that config is not None, we think they are using serverless inference, and in that case we allow `instance_type` and `instance_count` to be None.
2. If `serverless_inference_config` is not None and `instance_type` and `instance_count` are also not None, that's fine, since `instance_type` and `instance_count` will be ignored when building the `production_variant` when `serverless_inference_config` is set. And we also have low-level boto3 validation for the API request's parameters.
3. If `serverless_inference_config` is None, then `instance_type` and `instance_count` must have values, since they are required for instance-based inference.

(THE FOLLOWING IS THE ORIGINAL COMMENT, MAY BE CONFUSING)
If customers leave this as None (`instance_type` as None) and provide a `serverless_inference_config`, that means we're using serverless inference.
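The three cases in this reply reduce to a small decision rule. Below is a standalone sketch that mirrors (but does not reproduce) the validation added to `deploy` in this PR; `check_deploy_args` is a hypothetical helper for illustration:

```python
def check_deploy_args(instance_type=None, initial_instance_count=None,
                      serverless_inference_config=None):
    """Mirror of the argument-combination rules described above (illustrative only)."""
    if serverless_inference_config is not None:
        # Cases 1 and 2: a serverless config means serverless inference;
        # any instance arguments are ignored when the production variant
        # is built (and boto3 validates the final request).
        return "serverless"
    if instance_type and initial_instance_count:
        # Case 3: no serverless config, so both instance arguments are required.
        return "instance-based"
    raise ValueError(
        "Must specify instance type and instance count unless using serverless inference"
    )
```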