Skip to content

Commit 972379e

Browse files
authored
feature: HuggingFace Inference (#2511)
1 parent efdf3ec commit 972379e

File tree

9 files changed

+529
-34
lines changed

9 files changed

+529
-34
lines changed

src/sagemaker/huggingface/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.huggingface.estimator import HuggingFace # noqa: F401
17+
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401

src/sagemaker/huggingface/estimator.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
warn_if_parameter_server_with_multi_gpu,
2424
validate_smdistributed,
2525
)
26+
from sagemaker.huggingface.model import HuggingFaceModel
2627
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2728

2829
logger = logging.getLogger("sagemaker")
@@ -233,8 +234,58 @@ def create_model(
233234
dependencies=None,
234235
**kwargs
235236
):
236-
"""Placeholder docstring"""
237-
raise NotImplementedError("Creating model with HuggingFace training job is not supported.")
237+
"""Create a SageMaker ``HuggingFaceModel`` object that can be deployed to an ``Endpoint``.
238+
239+
Args:
240+
model_server_workers (int): Optional. The number of worker processes
241+
used by the inference server. If None, server will use one
242+
worker per vCPU.
243+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
244+
which is also used during transform jobs. If not specified, the
245+
role from the Estimator will be used.
246+
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
247+
the model. Default: use subnets and security groups from this Estimator.
248+
* 'Subnets' (list[str]): List of subnet ids.
249+
* 'SecurityGroupIds' (list[str]): List of security group ids.
250+
entry_point (str): Path (absolute or relative) to the local Python source file which
251+
should be executed as the entry point to training. If ``source_dir`` is specified,
252+
then ``entry_point`` must point to a file located at the root of ``source_dir``.
253+
Defaults to `None`.
254+
source_dir (str): Path (absolute or relative) to a directory with any other serving
255+
source code dependencies aside from the entry point file.
256+
If not specified, the model source directory from training is used.
257+
dependencies (list[str]): A list of paths to directories (absolute or relative) with
258+
any additional libraries that will be exported to the container.
259+
If not specified, the dependencies from training are used.
260+
This is not supported with "local code" in Local Mode.
261+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.huggingface.model.HuggingFaceModel`
262+
constructor.
263+
Returns:
264+
sagemaker.huggingface.model.HuggingFaceModel: A SageMaker ``HuggingFaceModel``
265+
object. See :func:`~sagemaker.huggingface.model.HuggingFaceModel` for full details.
266+
"""
267+
if "image_uri" not in kwargs:
268+
kwargs["image_uri"] = self.image_uri
269+
270+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
271+
272+
return HuggingFaceModel(
273+
role or self.role,
274+
model_data=self.model_data,
275+
entry_point=entry_point,
276+
transformers_version=self.framework_version,
277+
tensorflow_version=self.tensorflow_version,
278+
pytorch_version=self.pytorch_version,
279+
py_version=self.py_version,
280+
source_dir=(source_dir or self._model_source_dir()),
281+
container_log_level=self.container_log_level,
282+
code_location=self.code_location,
283+
model_server_workers=model_server_workers,
284+
sagemaker_session=self.sagemaker_session,
285+
vpc_config=self.get_vpc_config(vpc_config_override),
286+
dependencies=(dependencies or self.dependencies),
287+
**kwargs
288+
)
238289

239290
@classmethod
240291
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):

src/sagemaker/huggingface/model.py

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
import logging
17+
18+
import sagemaker
19+
from sagemaker import image_uris
20+
from sagemaker.deserializers import JSONDeserializer
21+
from sagemaker.fw_utils import (
22+
model_code_key_prefix,
23+
validate_version_or_image_args,
24+
)
25+
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
26+
from sagemaker.predictor import Predictor
27+
from sagemaker.serializers import JSONSerializer
28+
29+
logger = logging.getLogger("sagemaker")
30+
31+
32+
class HuggingFacePredictor(Predictor):
    """A Predictor for inference against HuggingFace Endpoints.

    Unless overridden, request payloads are serialized to JSON and responses
    are parsed from JSON (see the ``serializer`` and ``deserializer``
    arguments of ``__init__``).
    """

    def __init__(
        self,
        endpoint_name,
        sagemaker_session=None,
        serializer=None,
        deserializer=None,
    ):
        """Initialize an ``HuggingFacePredictor``.

        Args:
            endpoint_name (str): The name of the endpoint to perform inference
                on.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. Passed through unchanged to the base
                ``Predictor``, which is responsible for creating one when it is
                not specified.
            serializer (sagemaker.serializers.BaseSerializer): Optional.
                Defaults to a new ``JSONSerializer`` instance, which
                serializes the input data to JSON.
            deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
                Defaults to a new ``JSONDeserializer`` instance, which parses
                the JSON response.
        """
        # Build the defaults per instance instead of in the signature: objects
        # created in a default argument are evaluated once and would be shared
        # by every predictor, so tweaking one would leak into all the others.
        if serializer is None:
            serializer = JSONSerializer()
        if deserializer is None:
            deserializer = JSONDeserializer()

        super().__init__(
            endpoint_name,
            sagemaker_session,
            serializer=serializer,
            deserializer=deserializer,
        )
67+
68+
69+
def _validate_pt_tf_versions(pytorch_version, tensorflow_version, image_uri):
70+
"""Placeholder docstring"""
71+
72+
if image_uri is not None:
73+
return
74+
75+
if tensorflow_version is not None and pytorch_version is not None:
76+
raise ValueError(
77+
"tensorflow_version and pytorch_version are both not None. "
78+
"Specify only tensorflow_version or pytorch_version."
79+
)
80+
if tensorflow_version is None and pytorch_version is None:
81+
raise ValueError(
82+
"tensorflow_version and pytorch_version are both None. "
83+
"Specify either tensorflow_version or pytorch_version."
84+
)
85+
86+
87+
class HuggingFaceModel(FrameworkModel):
    """An HuggingFace SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

    _framework_name = "huggingface"

    def __init__(
        self,
        role,
        model_data=None,
        entry_point=None,
        transformers_version=None,
        tensorflow_version=None,
        pytorch_version=None,
        py_version=None,
        image_uri=None,
        predictor_cls=HuggingFacePredictor,
        model_server_workers=None,
        **kwargs
    ):
        """Initialize a HuggingFaceModel.

        Args:
            role (str): An AWS IAM role (either name or full ARN). The Amazon
                SageMaker training jobs and APIs that create Amazon SageMaker
                endpoints use this role to access training data and model
                artifacts. After the endpoint is created, the inference code
                might use the IAM role, if it needs to access an AWS resource.
            model_data (str): The S3 location of a SageMaker model data
                ``.tar.gz`` file.
            entry_point (str): Path (absolute or relative) to the Python source
                file which should be executed as the entry point to model
                hosting. If ``source_dir`` is specified, then ``entry_point``
                must point to a file located at the root of ``source_dir``.
                Defaults to None.
            transformers_version (str): transformers version you want to use for
                executing your inference code. Defaults to None. Required
                unless ``image_uri`` is provided.
            tensorflow_version (str): TensorFlow version you want to use for
                executing your inference code. Defaults to ``None``. Unless
                ``image_uri`` is provided, exactly one of ``tensorflow_version``
                and ``pytorch_version`` must be set. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            pytorch_version (str): PyTorch version you want to use for
                executing your inference code. Defaults to ``None``. Unless
                ``image_uri`` is provided, exactly one of ``tensorflow_version``
                and ``pytorch_version`` must be set. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            py_version (str): Python version you want to use for executing your
                inference code. Defaults to ``None``. Required unless
                ``image_uri`` is provided. ``"py2"`` is rejected.
            image_uri (str): A Docker image URI (default: None). If not specified,
                a HuggingFace serving image is resolved from the version
                arguments and the instance type when the model is registered or
                its container definition is prepared.
            predictor_cls (callable[str, sagemaker.session.Session]): A function
                to call to create a predictor with an endpoint name and
                SageMaker ``Session``. If specified, ``deploy()`` returns the
                result of invoking this function on the created endpoint name.
            model_server_workers (int): Optional. The number of worker processes
                used by the inference server. If None, server will use one
                worker per vCPU.
            **kwargs: Keyword arguments passed to the superclass
                :class:`~sagemaker.model.FrameworkModel` and, subsequently, its
                superclass :class:`~sagemaker.model.Model`.

        Raises:
            ValueError: If ``py_version`` is ``"py2"``, or if the version/image
                arguments fail validation.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.model.FrameworkModel` and
            :class:`~sagemaker.model.Model`.
        """
        validate_version_or_image_args(transformers_version, py_version, image_uri)
        _validate_pt_tf_versions(
            pytorch_version=pytorch_version,
            tensorflow_version=tensorflow_version,
            image_uri=image_uri,
        )
        if py_version == "py2":
            raise ValueError("py2 is not supported with HuggingFace images")

        # Stored before calling the base initializer, exactly as in the
        # original ordering, so they exist by the time the base class runs.
        self.framework_version = transformers_version
        self.pytorch_version = pytorch_version
        self.tensorflow_version = tensorflow_version
        self.py_version = py_version

        super().__init__(
            model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
        )

        self.model_server_workers = model_server_workers

    def register(
        self,
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=None,
        model_package_group_name=None,
        image_uri=None,
        model_metrics=None,
        metadata_properties=None,
        marketplace_cert=False,
        approval_status=None,
        description=None,
    ):
        """Creates a model package for creating SageMaker models or listing on Marketplace.

        Args:
            content_types (list): The supported MIME types for the input data.
            response_types (list): The supported MIME types for the output data.
            inference_instances (list): A list of the instance types that are used to
                generate inferences in real-time. The first entry is used to pick the
                serving image when no image URI is available.
            transform_instances (list): A list of the instance types on which a transformation
                job can be run or on which an endpoint can be deployed.
            model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
                using `model_package_name` makes the Model Package un-versioned (default: None).
            model_package_group_name (str): Model Package Group name, exclusive to
                `model_package_name`, using `model_package_group_name` makes the Model Package
                versioned (default: None).
            image_uri (str): Inference image uri for the container. If given, it also
                replaces ``self.image_uri``; otherwise ``self.image_uri`` is used,
                resolving a default serving image first if it is unset (default: None).
            model_metrics (ModelMetrics): ModelMetrics object (default: None).
            metadata_properties (MetadataProperties): MetadataProperties object (default: None).
            marketplace_cert (bool): A boolean value indicating if the Model Package is certified
                for AWS Marketplace (default: False).
            approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
                or "PendingManualApproval" (default: None).
                NOTE(review): earlier documentation stated the effective default is
                "PendingManualApproval" -- confirm against the base class.
            description (str): Model Package description (default: None).

        Returns:
            A `sagemaker.model.ModelPackage` instance.

        Raises:
            ValueError: If no image URI is available and ``inference_instances`` is
                empty, so the serving image cannot be chosen.
        """
        # Tolerate an empty/None list: the instance type is only needed to pick
        # an image, so indexing unconditionally would fail even when an image
        # URI has been supplied.
        instance_type = inference_instances[0] if inference_instances else None
        self._init_sagemaker_session_if_does_not_exist(instance_type)

        if image_uri:
            self.image_uri = image_uri
        if not self.image_uri:
            if instance_type is None:
                # Same contract and message as prepare_container_def().
                raise ValueError(
                    "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
                )
            self.image_uri = self.serving_image_uri(
                region_name=self.sagemaker_session.boto_session.region_name,
                instance_type=instance_type,
            )
        return super().register(
            content_types,
            response_types,
            inference_instances,
            transform_instances,
            model_package_name,
            model_package_group_name,
            image_uri,
            model_metrics,
            metadata_properties,
            marketplace_cert,
            approval_status,
            description,
        )

    def prepare_container_def(self, instance_type=None, accelerator_type=None):
        """A container definition with framework configuration set in model environment variables.

        Args:
            instance_type (str): The EC2 instance type to deploy this Model to.
                For example, 'ml.p2.xlarge'.
            accelerator_type (str): The Elastic Inference accelerator type to
                deploy to the instance for loading and making inferences to the
                model.

        Returns:
            dict[str, str]: A container definition object usable with the
            CreateModel API.

        Raises:
            ValueError: If neither ``self.image_uri`` nor ``instance_type`` is
                available to determine the serving image.
        """
        deploy_image = self.image_uri
        if not deploy_image:
            if instance_type is None:
                raise ValueError(
                    "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
                )

            region_name = self.sagemaker_session.boto_session.region_name
            deploy_image = self.serving_image_uri(
                region_name, instance_type, accelerator_type=accelerator_type
            )

        deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
        self._upload_code(deploy_key_prefix, repack=True)

        # Copy so the framework variables added below do not mutate self.env.
        deploy_env = dict(self.env)
        deploy_env.update(self._framework_env_vars())

        if self.model_server_workers:
            deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
        return sagemaker.container_def(
            deploy_image, self.repacked_model_data or self.model_data, deploy_env
        )

    def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
        """Create a URI for the serving image.

        Args:
            region_name (str): AWS region where the image is uploaded.
            instance_type (str): SageMaker instance type. Used to determine device type
                (cpu/gpu/family-specific optimized).
            accelerator_type (str): The Elastic Inference accelerator type to
                deploy to the instance for loading and making inferences to the
                model.

        Returns:
            str: The appropriate image URI based on the given parameters.

        Raises:
            ValueError: If neither ``tensorflow_version`` nor ``pytorch_version``
                is set on this model, so no base framework tag can be built.
        """
        if image_uris._processor(instance_type, ["cpu", "gpu"]) == "gpu":
            container_version = "cu110-ubuntu18.04"
        else:
            container_version = "ubuntu18.04"

        # TensorFlow takes precedence when set, otherwise PyTorch is used.
        if self.tensorflow_version is not None:  # pylint: disable=no-member
            base_framework_version = (
                f"tensorflow{self.tensorflow_version}"  # pylint: disable=no-member
            )
        elif self.pytorch_version is not None:  # pylint: disable=no-member
            base_framework_version = f"pytorch{self.pytorch_version}"  # pylint: disable=no-member
        else:
            # Reachable when the model was built from an explicit image_uri
            # (version validation is skipped then). Fail loudly instead of
            # looking up an image for the bogus tag "pytorchNone".
            raise ValueError(
                "Cannot determine the serving image: "
                "tensorflow_version and pytorch_version are both None."
            )

        return image_uris.retrieve(
            self._framework_name,
            region_name,
            version=self.framework_version,
            py_version=self.py_version,
            instance_type=instance_type,
            accelerator_type=accelerator_type,
            image_scope="inference",
            base_framework_version=base_framework_version,
            container_version=container_version,
        )

0 commit comments

Comments
 (0)