
Commit 7e2c7ab

feature: Inferentia Neuron support for HuggingFace (#2976)
1 parent dfc6eee · commit 7e2c7ab

File tree

5 files changed: +250 -15 lines changed


src/sagemaker/huggingface/model.py

Lines changed: 117 additions & 0 deletions

```diff
@@ -25,6 +25,7 @@
 from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
 from sagemaker.predictor import Predictor
 from sagemaker.serializers import JSONSerializer
+from sagemaker.session import Session
 
 logger = logging.getLogger("sagemaker")
 
@@ -169,9 +170,125 @@ def __init__(
         super(HuggingFaceModel, self).__init__(
             model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
         )
+        self.sagemaker_session = self.sagemaker_session or Session()
 
         self.model_server_workers = model_server_workers
 
+    # TODO: Remove the following function
+    # botocore needs to add huggingface to the list of valid Neo-compilable frameworks.
+    # Ideally, with the Inferentia framework, a call to .compile(...) would create the
+    # image_uri; currently, a call to compile(...) raises a `ValidationException`.
+    def deploy(
+        self,
+        initial_instance_count=None,
+        instance_type=None,
+        serializer=None,
+        deserializer=None,
+        accelerator_type=None,
+        endpoint_name=None,
+        tags=None,
+        kms_key=None,
+        wait=True,
+        data_capture_config=None,
+        async_inference_config=None,
+        serverless_inference_config=None,
+        **kwargs,
+    ):
+        """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
+
+        Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an
+        ``Endpoint`` from this ``Model``. If ``self.predictor_cls`` is not None,
+        this method returns the result of invoking ``self.predictor_cls`` on
+        the created endpoint name.
+
+        The name of the created model is accessible in the ``name`` field of
+        this ``Model`` after deploy returns.
+
+        The name of the created endpoint is accessible in the
+        ``endpoint_name`` field of this ``Model`` after deploy returns.
+
+        Args:
+            initial_instance_count (int): The initial number of instances to run
+                in the ``Endpoint`` created from this ``Model``. If not using
+                serverless inference, it must be a number greater than or equal
+                to 1. (default: None)
+            instance_type (str): The EC2 instance type to deploy this Model to.
+                For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
+                serverless inference, it is required to deploy a model.
+                (default: None)
+            serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
+                serializer object, used to encode data for an inference endpoint
+                (default: None). If ``serializer`` is not None, then
+                ``serializer`` will override the default serializer. The
+                default serializer is set by the ``predictor_cls``.
+            deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
+                deserializer object, used to decode data from an inference
+                endpoint (default: None). If ``deserializer`` is not None, then
+                ``deserializer`` will override the default deserializer. The
+                default deserializer is set by the ``predictor_cls``.
+            accelerator_type (str): Type of Elastic Inference accelerator to
+                deploy this model for model loading and inference, for example,
+                'ml.eia1.medium'. If not specified, no Elastic Inference
+                accelerator will be attached to the endpoint. For more
+                information:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
+            endpoint_name (str): The name of the endpoint to create (default:
+                None). If not specified, a unique endpoint name will be created.
+            tags (List[dict[str, str]]): The list of tags to attach to this
+                specific endpoint.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                data on the storage volume attached to the instance hosting the
+                endpoint.
+            wait (bool): Whether the call should wait until the deployment of
+                this model completes (default: True).
+            data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
+                configuration related to Endpoint data capture for use with
+                Amazon SageMaker Model Monitoring. Default: None.
+            async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig): Specifies
+                configuration related to an async endpoint. Use this configuration when
+                creating an async endpoint for async inference. If an empty config object
+                is passed through, a default config is used to deploy the async endpoint.
+                Deploys a real-time endpoint if it's None. (default: None)
+            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
+                Specifies configuration related to a serverless endpoint. Use this
+                configuration when creating a serverless endpoint for serverless inference.
+                If an empty object is passed through, the pre-defined values in the
+                ``ServerlessInferenceConfig`` class are used to deploy the serverless
+                endpoint. Deploys an instance-based endpoint if it's None. (default: None)
+        Raises:
+            ValueError: If the argument combination check fails in any of these
+                circumstances:
+                - If no role is specified or
+                - If serverless inference config is not specified and instance type and
+                  instance count are also not specified or
+                - If a wrong type of object is provided as serverless inference config or
+                  async inference config
+        Returns:
+            callable[string, sagemaker.session.Session] or None: Invocation of
+                ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
+                is not None. Otherwise, returns None.
+        """
+
+        if not self.image_uri and instance_type.startswith("ml.inf"):
+            self.image_uri = self.serving_image_uri(
+                region_name=self.sagemaker_session.boto_session.region_name,
+                instance_type=instance_type,
+            )
+
+        return super(HuggingFaceModel, self).deploy(
+            initial_instance_count,
+            instance_type,
+            serializer,
+            deserializer,
+            accelerator_type,
+            endpoint_name,
+            tags,
+            kms_key,
+            wait,
+            data_capture_config,
+            async_inference_config,
+            serverless_inference_config,
+        )
+
     def register(
         self,
         content_types,
```
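
For context, a minimal usage sketch of the new code path: when `image_uri` is unset and the instance type starts with `ml.inf`, `deploy` resolves the Neuron serving image before delegating to the parent `Model.deploy`. The bucket, role ARN, and model archive below are placeholders, not from the commit; the archive is assumed to contain a model already compiled for Inferentia (e.g., traced with torch-neuron).

```python
from sagemaker.huggingface import HuggingFaceModel

# Hypothetical artifact and role; substitute your own.
huggingface_model = HuggingFaceModel(
    model_data="s3://my-bucket/neuron-model.tar.gz",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    transformers_version="4.12",
    pytorch_version="1.9",
    py_version="py37",
)

# "ml.inf1.xlarge" starts with "ml.inf", so deploy() fills in self.image_uri
# with the huggingface-pytorch-inference-neuron image for the session's region
# before delegating to Model.deploy().
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.inf1.xlarge",
)
```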
src/sagemaker/image_uri_config/huggingface-neuron.json

Lines changed: 44 additions & 0 deletions (new file)

```json
{
    "inference": {
        "processors": ["inf"],
        "version_aliases": {"4.12": "4.12.3"},
        "versions": {
            "4.12.3": {
                "version_aliases": {"pytorch1.9": "pytorch1.9.1"},
                "pytorch1.9.1": {
                    "py_versions": ["py37"],
                    "repository": "huggingface-pytorch-inference-neuron",
                    "registries": {
                        "af-south-1": "626614931356",
                        "ap-east-1": "871362719292",
                        "ap-northeast-1": "763104351884",
                        "ap-northeast-2": "763104351884",
                        "ap-northeast-3": "364406365360",
                        "ap-south-1": "763104351884",
                        "ap-southeast-1": "763104351884",
                        "ap-southeast-2": "763104351884",
                        "ca-central-1": "763104351884",
                        "cn-north-1": "727897471807",
                        "cn-northwest-1": "727897471807",
                        "eu-central-1": "763104351884",
                        "eu-north-1": "763104351884",
                        "eu-west-1": "763104351884",
                        "eu-west-2": "763104351884",
                        "eu-west-3": "763104351884",
                        "eu-south-1": "692866216735",
                        "me-south-1": "217643126080",
                        "sa-east-1": "763104351884",
                        "us-east-1": "763104351884",
                        "us-east-2": "763104351884",
                        "us-gov-west-1": "442386744353",
                        "us-iso-east-1": "886529160074",
                        "us-west-1": "763104351884",
                        "us-west-2": "763104351884"
                    },
                    "container_version": {"inf": "ubuntu18.04"},
                    "sdk_versions": ["sdk1.17.1"]
                }
            }
        }
    }
}
```
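
To illustrate how these fields compose (a sketch, not part of the commit): `retrieve` selects the account from `registries` by region, the repository name from `repository`, and assembles the tag from the base framework version, the transformers version, the inference tool, `py_versions`, `sdk_versions`, and `container_version`.

```python
from sagemaker import image_uris

uri = image_uris.retrieve(
    framework="huggingface",
    region="us-east-1",
    version="4.12.3",
    base_framework_version="pytorch1.9.1",
    py_version="py37",
    image_scope="inference",
    instance_type="ml.inf1.xlarge",  # detected as Inferentia, so inference_tool="neuron"
)
# Expected to resolve to something like:
# 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference-neuron:1.9.1-transformers4.12.3-neuron-py37-sdk1.17.1-ubuntu18.04
```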

src/sagemaker/image_uris.py

Lines changed: 53 additions & 14 deletions

```diff
@@ -24,7 +24,6 @@
 from sagemaker.spark import defaults
 from sagemaker.jumpstart import artifacts
 
-
 logger = logging.getLogger(__name__)
 
 ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
@@ -47,6 +46,8 @@ def retrieve(
     model_version=None,
     tolerate_vulnerable_model=False,
     tolerate_deprecated_model=False,
+    sdk_version=None,
+    inference_tool=None,
 ) -> str:
     """Retrieves the ECR URI for the Docker image matching the given arguments.
@@ -88,6 +89,11 @@ def retrieve(
         tolerate_deprecated_model (bool): True if deprecated versions of model specifications
             should be tolerated without an exception raised. If False, raises an exception
             if the version of the model is deprecated. (Default: False).
+        sdk_version (str): The version of the SDK that will be used in the image retrieval.
+            (default: None).
+        inference_tool (str): The tool that will be used to aid in the inference.
+            Valid values: "neuron", None.
+            (default: None).
 
     Returns:
         str: The ECR URI for the corresponding SageMaker Docker image.
@@ -100,7 +106,6 @@ def retrieve(
         DeprecatedJumpStartModelError: If the version of the model is deprecated.
     """
     if is_jumpstart_model_input(model_id, model_version):
-
         return artifacts._retrieve_image_uri(
             model_id,
             model_version,
@@ -118,9 +123,13 @@ def retrieve(
             tolerate_vulnerable_model,
             tolerate_deprecated_model,
         )
-
     if training_compiler_config is None:
-        config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
+        _framework = framework
+        if framework == HUGGING_FACE_FRAMEWORK:
+            inference_tool = _get_inference_tool(inference_tool, instance_type)
+            if inference_tool == "neuron":
+                _framework = f"{framework}-{inference_tool}"
+        config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
     elif framework == HUGGING_FACE_FRAMEWORK:
         config = _config_for_framework_and_scope(
             framework + "-training-compiler", image_scope, accelerator_type
@@ -129,6 +138,7 @@ def retrieve(
         raise ValueError(
             "Unsupported Configuration: Training Compiler is only supported with HuggingFace"
         )
+
     original_version = version
     version = _validate_version_and_set_if_needed(version, config, framework)
     version_config = config["versions"][_version_for_config(version, config)]
@@ -138,7 +148,6 @@ def retrieve(
         full_base_framework_version = version_config["version_aliases"].get(
             base_framework_version, base_framework_version
         )
-
         _validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
         version_config = version_config.get(full_base_framework_version)
 
@@ -161,25 +170,37 @@ def retrieve(
         pt_or_tf_version = (
             re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
         )
-
         _version = original_version
+
         if repo in [
             "huggingface-pytorch-trcomp-training",
             "huggingface-tensorflow-trcomp-training",
         ]:
             _version = version
+
+        if repo in ["huggingface-pytorch-inference-neuron"]:
+            if not sdk_version:
+                sdk_version = _get_latest_versions(version_config["sdk_versions"])
+            container_version = sdk_version + "-" + container_version
+            if config.get("version_aliases").get(original_version):
+                _version = config.get("version_aliases")[original_version]
+            if (
+                config.get("versions", {})
+                .get(_version, {})
+                .get("version_aliases", {})
+                .get(base_framework_version, {})
+            ):
+                _base_framework_version = config.get("versions")[_version]["version_aliases"][
+                    base_framework_version
+                ]
+                pt_or_tf_version = (
+                    re.compile("^(pytorch|tensorflow)(.*)$").match(_base_framework_version).group(2)
+                )
 
         tag_prefix = f"{pt_or_tf_version}-transformers{_version}"
-
     else:
         tag_prefix = version_config.get("tag_prefix", version)
 
-    tag = _format_tag(
-        tag_prefix,
-        processor,
-        py_version,
-        container_version,
-    )
+    tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
 
     if _should_auto_select_container_version(instance_type, distribution):
         container_versions = {
@@ -248,6 +269,20 @@ def config_for_framework(framework):
         return json.load(f)
 
 
+def _get_inference_tool(inference_tool, instance_type):
+    """Extract the inference tool name from the instance type."""
+    if not inference_tool and instance_type:
+        match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+        if match and match[1].startswith("inf"):
+            return "neuron"
+    return inference_tool
+
+
+def _get_latest_versions(list_of_versions):
+    """Extract the latest version from the input list of available versions."""
+    return sorted(list_of_versions, reverse=True)[0]
+
+
 def _validate_accelerator_type(accelerator_type):
     """Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
     if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook":
@@ -310,6 +345,8 @@ def _processor(instance_type, available_processors):
 
     if instance_type.startswith("local"):
         processor = "cpu" if instance_type == "local" else "gpu"
+    elif instance_type.startswith("neuron"):
+        processor = "neuron"
     else:
         # looks for either "ml.<family>.<size>" or "ml_<family>"
         match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
@@ -387,8 +424,10 @@ def _validate_arg(arg, available_options, arg_name):
     )
 
 
-def _format_tag(tag_prefix, processor, py_version, container_version):
+def _format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
     """Creates a tag for the image URI."""
+    if inference_tool:
+        return "-".join(x for x in (tag_prefix, inference_tool, py_version, container_version) if x)
     return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)
```

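A self-contained sketch of the helper logic above (helpers renamed locally, since the real functions are private to `sagemaker.image_uris`), showing how Inferentia instance types are detected and how the tag swaps the processor slot for the inference tool:

```python
import re

def get_inference_tool(inference_tool, instance_type):
    """Return "neuron" for Inferentia instance types such as ml.inf1.xlarge."""
    if not inference_tool and instance_type:
        # Same pattern as _processor(): "ml.<family>.<size>" or "ml_<family>"
        match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
        if match and match[1].startswith("inf"):
            return "neuron"
    return inference_tool

def format_tag(tag_prefix, processor, py_version, container_version, inference_tool=None):
    """Neuron image tags carry the inference tool where the processor would go."""
    if inference_tool:
        return "-".join(x for x in (tag_prefix, inference_tool, py_version, container_version) if x)
    return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)

assert get_inference_tool(None, "ml.inf1.xlarge") == "neuron"
assert get_inference_tool(None, "ml.p3.2xlarge") is None
assert (
    format_tag("1.9.1-transformers4.12.3", "cpu", "py37", "sdk1.17.1-ubuntu18.04", "neuron")
    == "1.9.1-transformers4.12.3-neuron-py37-sdk1.17.1-ubuntu18.04"
)
```
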
tests/conftest.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -269,6 +269,21 @@ def huggingface_tensorflow_latest_training_py_version():
     return "py37"
 
 
+@pytest.fixture(scope="module")
+def huggingface_neuron_latest_inference_pytorch_version():
+    return "1.9"
+
+
+@pytest.fixture(scope="module")
+def huggingface_neuron_latest_inference_transformer_version():
+    return "4.12"
+
+
+@pytest.fixture(scope="module")
+def huggingface_neuron_latest_inference_py_version():
+    return "py37"
+
+
 @pytest.fixture(scope="module")
 def pytorch_eia_py_version():
     return "py3"
```

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 21 additions & 1 deletion

```diff
@@ -19,7 +19,7 @@
 import pytest
 from mock import MagicMock, Mock, patch
 
-from sagemaker.huggingface import HuggingFace
+from sagemaker.huggingface import HuggingFace, HuggingFaceModel
 
 from .huggingface_utils import get_full_gpu_image_uri, GPU_INSTANCE_TYPE, REGION
 
@@ -252,6 +252,26 @@ def test_huggingface(
     assert actual_train_args == expected_train_args
 
 
+def test_huggingface_neuron(
+    sagemaker_session,
+    huggingface_neuron_latest_inference_pytorch_version,
+    huggingface_neuron_latest_inference_transformer_version,
+    huggingface_neuron_latest_inference_py_version,
+):
+    inputs = "s3://mybucket/train"
+    huggingface_model = HuggingFaceModel(
+        model_data=inputs,
+        transformers_version=huggingface_neuron_latest_inference_transformer_version,
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        pytorch_version=huggingface_neuron_latest_inference_pytorch_version,
+        py_version=huggingface_neuron_latest_inference_py_version,
+    )
+    container = huggingface_model.prepare_container_def("ml.inf.xlarge")
+    assert container["Image"]
+
+
 def test_attach(
     sagemaker_session,
     huggingface_training_version,
```
