feature: Adds support for Serverless inference #2831


Merged: 1 commit, Jan 14, 2022
2 changes: 1 addition & 1 deletion setup.py
@@ -34,7 +34,7 @@ def read_version():
# Declare minimal set for installation
required_packages = [
"attrs",
"boto3>=1.20.18",
"boto3>=1.20.21",
"google-pasta",
"numpy>=1.9.0",
"protobuf>=3.1",
26 changes: 19 additions & 7 deletions src/sagemaker/estimator.py
@@ -852,8 +852,8 @@ def logs(self):

def deploy(
self,
initial_instance_count,
instance_type,
initial_instance_count=None,
instance_type=None,
serializer=None,
deserializer=None,
accelerator_type=None,
@@ -864,6 +864,7 @@ def deploy(
kms_key=None,
data_capture_config=None,
tags=None,
serverless_inference_config=None,
**kwargs,
):
"""Deploy the trained model to an Amazon SageMaker endpoint.
@@ -874,10 +875,14 @@
http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html

Args:
initial_instance_count (int): Minimum number of EC2 instances to
deploy to an endpoint for prediction.
instance_type (str): Type of EC2 instance to deploy to an endpoint
for prediction, for example, 'ml.c4.xlarge'.
initial_instance_count (int): The initial number of instances to run
in the ``Endpoint`` created from this ``Model``. If not using
serverless inference, it needs to be a number greater than or equal
to 1 (default: None)

Review comment: If I understand correctly, if the customer leaves this as None it means serverless, otherwise instance-based, right?

@bhaoz (Contributor, Author), Jan 11, 2022:

Sorry for the confusion, the intended logic is:

  • We use serverless_inference_config to infer whether users are using serverless inference. If that config is not None, we treat the deployment as serverless, and in that case instance_type and initial_instance_count are allowed to be None.
  • If serverless_inference_config is set and instance_type and initial_instance_count are also set, that is fine: the instance settings are ignored when building the production variant, and there is also low-level boto3 validation of the API request's parameters.
  • If serverless_inference_config is None, then instance_type and initial_instance_count must have values, since they are required for instance-based inference.
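A minimal sketch of the three cases above (``est`` is a hypothetical trained estimator; instance values are illustrative):

from sagemaker.serverless import ServerlessInferenceConfig

# Case 1: serverless inference; instance settings may be omitted.
predictor = est.deploy(serverless_inference_config=ServerlessInferenceConfig())

# Case 2: serverless config plus instance settings; the instance settings
# are ignored when the production variant is built.
predictor = est.deploy(
    initial_instance_count=1,
    instance_type="ml.c5.xlarge",
    serverless_inference_config=ServerlessInferenceConfig(),
)

# Case 3: no serverless config; instance settings are required.
predictor = est.deploy(initial_instance_count=1, instance_type="ml.c5.xlarge")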

instance_type (str): The EC2 instance type to deploy this Model to.
For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
serverless inference, it is required in order to deploy a model.
(default: None)
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
serializer object, used to encode data for an inference endpoint
(default: None). If ``serializer`` is not None, then
@@ -910,6 +915,11 @@
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
configuration related to Endpoint data capture for use with
Amazon SageMaker Model Monitoring. Default: None.
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to a serverless endpoint. Use this configuration
when creating a serverless endpoint and making serverless inference requests. If
an empty object is passed through, the pre-defined values in the
``ServerlessInferenceConfig`` class are used to deploy the serverless endpoint (default: None)
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
endpoint. Example:
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
@@ -927,14 +937,15 @@
endpoint and obtain inferences.
"""
removed_kwargs("update_endpoint", kwargs)
is_serverless = serverless_inference_config is not None
self._ensure_latest_training_job()
self._ensure_base_job_name()
default_name = name_from_base(self.base_job_name)
endpoint_name = endpoint_name or default_name
model_name = model_name or default_name

self.deploy_instance_type = instance_type
if use_compiled_model:
if use_compiled_model and not is_serverless:
family = "_".join(instance_type.split(".")[:-1])
if family not in self._compiled_models:
raise ValueError(
@@ -959,6 +970,7 @@
wait=wait,
kms_key=kms_key,
data_capture_config=data_capture_config,
serverless_inference_config=serverless_inference_config,
)

def register(
60 changes: 48 additions & 12 deletions src/sagemaker/model.py
@@ -32,6 +32,7 @@
from sagemaker.inputs import CompilationInput
from sagemaker.deprecations import removed_kwargs
from sagemaker.predictor import PredictorBase
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.transformer import Transformer

LOGGER = logging.getLogger("sagemaker")
@@ -209,7 +210,7 @@ def register(
model_package_arn=model_package.get("ModelPackageArn"),
)

def _init_sagemaker_session_if_does_not_exist(self, instance_type):
def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
Review comment: what does this do?

@bhaoz (Contributor, Author): Serverless inference needs instance_type to be None, so this default was added to support the serverless case.
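A rough sketch of what this helper does (a hypothetical simplification, not the actual implementation):

import sagemaker

def init_session_if_missing(model, instance_type=None):
    # Pick a LocalSession for local mode, otherwise a regular Session.
    # With instance_type=None (the serverless case), this falls through
    # to a regular Session.
    if model.sagemaker_session:
        return
    if instance_type in ("local", "local_gpu"):
        model.sagemaker_session = sagemaker.LocalSession()
    else:
        model.sagemaker_session = sagemaker.Session()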

"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.

The type of session object is determined by the instance type.
@@ -688,8 +689,8 @@ def compile(

def deploy(
self,
initial_instance_count,
instance_type,
initial_instance_count=None,
instance_type=None,
serializer=None,
deserializer=None,
accelerator_type=None,
@@ -698,6 +699,7 @@
kms_key=None,
wait=True,
data_capture_config=None,
serverless_inference_config=None,
**kwargs,
):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -715,9 +717,13 @@

Args:
initial_instance_count (int): The initial number of instances to run
in the ``Endpoint`` created from this ``Model``.
in the ``Endpoint`` created from this ``Model``. If not using
serverless inference, it needs to be a number greater than or equal
to 1 (default: None)
instance_type (str): The EC2 instance type to deploy this Model to.
For example, 'ml.p2.xlarge', or 'local' for local mode.
For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
serverless inference, it is required in order to deploy a model.
(default: None)
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
serializer object, used to encode data for an inference endpoint
(default: None). If ``serializer`` is not None, then
@@ -746,7 +752,17 @@
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
configuration related to Endpoint data capture for use with
Amazon SageMaker Model Monitoring. Default: None.

serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
Specifies configuration related to a serverless endpoint. Use this configuration
when creating a serverless endpoint and making serverless inference requests. If
an empty object is passed through, the pre-defined values in the
``ServerlessInferenceConfig`` class are used to deploy the serverless endpoint (default: None)
Raises:
ValueError: If the argument combination check fails in any of these circumstances:
- no role is specified, or
- a serverless inference config is not specified while the instance type and instance
count are also not specified, or
- a wrong type of object is provided as the serverless inference config
Returns:
callable[string, sagemaker.session.Session] or None: Invocation of
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
@@ -758,27 +774,47 @@
if self.role is None:
raise ValueError("Role can not be null for deploying a model")

if instance_type.startswith("ml.inf") and not self._is_compiled_model:
is_serverless = serverless_inference_config is not None
if not is_serverless and not (instance_type and initial_instance_count):
raise ValueError(
"Must specify instance type and instance count unless using serverless inference"
)

if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig):
raise ValueError(
"serverless_inference_config needs to be a ServerlessInferenceConfig object"
)

if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model:
Review comment: what is this for?

@bhaoz (Contributor, Author): Same as above: instance_type can now be None, so the check has to guard the serverless case.

LOGGER.warning(
"Your model is not compiled. Please compile your model before using Inferentia."
)

compiled_model_suffix = "-".join(instance_type.split(".")[:-1])
if self._is_compiled_model:
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
if self._is_compiled_model and not is_serverless:
self._ensure_base_name_if_needed(self.image_uri)
if self._base_name is not None:
self._base_name = "-".join((self._base_name, compiled_model_suffix))

self._create_sagemaker_model(instance_type, accelerator_type, tags)

serverless_inference_config_dict = (
serverless_inference_config._to_request_dict() if is_serverless else None
)
production_variant = sagemaker.production_variant(
self.name, instance_type, initial_instance_count, accelerator_type=accelerator_type
self.name,
instance_type,
initial_instance_count,
accelerator_type=accelerator_type,
Review comment: We are not supporting accelerators for serverless; will the code below have any impact?

@bhaoz (Contributor, Author): This will be validated by the low-level botocore/boto3 libraries, so it should be fine.

serverless_inference_config=serverless_inference_config_dict,
)
if endpoint_name:
self.endpoint_name = endpoint_name
else:
base_endpoint_name = self._base_name or utils.base_from_name(self.name)
if self._is_compiled_model and not base_endpoint_name.endswith(compiled_model_suffix):
base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
if self._is_compiled_model and not is_serverless:
if not base_endpoint_name.endswith(compiled_model_suffix):
base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
self.endpoint_name = utils.name_from_base(base_endpoint_name)

data_capture_config_dict = None
3 changes: 3 additions & 0 deletions src/sagemaker/serverless/__init__.py
@@ -13,3 +13,6 @@
"""Classes for performing machine learning on serverless compute."""
from sagemaker.serverless.model import LambdaModel # noqa: F401
from sagemaker.serverless.predictor import LambdaPredictor # noqa: F401
from sagemaker.serverless.serverless_inference_config import ( # noqa: F401
ServerlessInferenceConfig,
)
54 changes: 54 additions & 0 deletions src/sagemaker/serverless/serverless_inference_config.py
@@ -0,0 +1,54 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains code related to the ServerlessInferenceConfig class.

This code is used for configuring serverless inference endpoints. Use it when
deploying the model to an endpoint.
"""
from __future__ import print_function, absolute_import


class ServerlessInferenceConfig(object):
"""Configuration object passed in when deploying models to Amazon SageMaker Endpoints.

This object specifies configuration related to a serverless endpoint. Use this configuration
when creating a serverless endpoint and making serverless inference requests.
"""

def __init__(
self,
memory_size_in_mb=2048,
max_concurrency=5,
):
"""Initialize a ServerlessInferenceConfig object for serverless inference configuration.

Args:
memory_size_in_mb (int): Optional. The memory size of your serverless endpoint.
Valid values are in 1 GB increments: 1024 MB, 2048 MB, 3072 MB, 4096 MB,
5120 MB, or 6144 MB. If no value is provided, the SageMaker Python SDK
uses a default of 2048 MB. (Default: 2048)
Review comment: Default is 1 GB, please check with Michael once.

@bhaoz (Contributor, Author): There is no default for memory or max concurrency shown in the boto3 API; I can double-check with Michael. The reason we set defaults here in the SageMaker SDK is to simplify the customer use case: users do not need to specify this config and we will pick a proper value for them, and if they want different values they can set the config they prefer.

Reviewer: Yep, the default is controlled by the SageMaker SDK here; there is no inherent default in the CreateEndpointConfig API itself.

max_concurrency (int): Optional. The maximum number of concurrent invocations
Review comment: Is there a default max concurrency with boto3 today? Good to double-check with Michael.

Reviewer: Not with boto3.

your serverless endpoint can process. If no value is provided, the
SageMaker Python SDK uses a default of 5. (Default: 5)
"""
self.memory_size_in_mb = memory_size_in_mb
self.max_concurrency = max_concurrency

def _to_request_dict(self):
"""Generates a request dictionary using the parameters provided to the class."""
request_dict = {
"MemorySizeInMB": self.memory_size_in_mb,
"MaxConcurrency": self.max_concurrency,
}

return request_dict
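A quick usage sketch of the new class (values are illustrative):

from sagemaker.serverless import ServerlessInferenceConfig

# Omit both arguments to get the SDK defaults (2048 MB, 5).
config = ServerlessInferenceConfig(memory_size_in_mb=4096, max_concurrency=10)
print(config._to_request_dict())
# {'MemorySizeInMB': 4096, 'MaxConcurrency': 10}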
17 changes: 13 additions & 4 deletions src/sagemaker/session.py
@@ -4382,11 +4382,12 @@ def pipeline_container_def(models, instance_type=None):

def production_variant(
model_name,
instance_type,
initial_instance_count=1,
instance_type=None,
initial_instance_count=None,
variant_name="AllTraffic",
initial_weight=1,
accelerator_type=None,
serverless_inference_config=None,
):
"""Create a production variant description suitable for use in a ``ProductionVariant`` list.

@@ -4405,21 +4406,29 @@
accelerator_type (str): Type of Elastic Inference accelerator for this production variant.
For example, 'ml.eia1.medium'.
For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
serverless_inference_config (dict): Specifies a configuration dict related to a serverless
endpoint. The dict is converted from a sagemaker.serverless.ServerlessInferenceConfig
object (default: None)

Returns:
dict[str, str]: A SageMaker ``ProductionVariant`` description
"""
production_variant_configuration = {
"ModelName": model_name,
"InstanceType": instance_type,
"InitialInstanceCount": initial_instance_count,
"VariantName": variant_name,
"InitialVariantWeight": initial_weight,
}

if accelerator_type:
production_variant_configuration["AcceleratorType"] = accelerator_type

if serverless_inference_config:
production_variant_configuration["ServerlessConfig"] = serverless_inference_config
else:
initial_instance_count = initial_instance_count or 1
production_variant_configuration["InitialInstanceCount"] = initial_instance_count
production_variant_configuration["InstanceType"] = instance_type

return production_variant_configuration
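For illustration, the two shapes this function now returns (a sketch based on the diff above; the model name is a placeholder):

from sagemaker.session import production_variant

# Instance-based variant:
production_variant("my-model", instance_type="ml.c5.xlarge", initial_instance_count=2)
# {'ModelName': 'my-model', 'VariantName': 'AllTraffic', 'InitialVariantWeight': 1,
#  'InitialInstanceCount': 2, 'InstanceType': 'ml.c5.xlarge'}

# Serverless variant; the instance fields are omitted entirely:
production_variant(
    "my-model",
    serverless_inference_config={"MemorySizeInMB": 2048, "MaxConcurrency": 5},
)
# {'ModelName': 'my-model', 'VariantName': 'AllTraffic', 'InitialVariantWeight': 1,
#  'ServerlessConfig': {'MemorySizeInMB': 2048, 'MaxConcurrency': 5}}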


6 changes: 4 additions & 2 deletions src/sagemaker/tensorflow/model.py
@@ -258,8 +258,8 @@ def register(

def deploy(
self,
initial_instance_count,
instance_type,
initial_instance_count=None,
instance_type=None,
serializer=None,
deserializer=None,
accelerator_type=None,
@@ -269,6 +269,7 @@
wait=True,
data_capture_config=None,
update_endpoint=None,
serverless_inference_config=None,
):
"""Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""

@@ -287,6 +288,7 @@
kms_key=kms_key,
wait=wait,
data_capture_config=data_capture_config,
serverless_inference_config=serverless_inference_config,
update_endpoint=update_endpoint,
)
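Putting it together, a minimal end-to-end sketch of deploying a serverless endpoint with this change (the image URI, model data path, and role are placeholders):

from sagemaker.model import Model
from sagemaker.serverless import ServerlessInferenceConfig

# Placeholder artifact, image, and role; substitute your own.
model = Model(
    image_uri="<inference-image-uri>",
    model_data="s3://<bucket>/model.tar.gz",
    role="<execution-role-arn>",
)

# No instance_type or initial_instance_count is needed; the serverless
# config below drives how the endpoint is provisioned.
predictor = model.deploy(
    serverless_inference_config=ServerlessInferenceConfig(
        memory_size_in_mb=2048,
        max_concurrency=5,
    ),
)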
