Skip to content

Commit cc93061

Browse files
icywang86ruiRui Wang Napieralski
and
Rui Wang Napieralski
authored
feature: add HuggingFace framework estimator (#2231)
Co-authored-by: Rui Wang Napieralski <[email protected]>
1 parent a5ebdca commit cc93061

File tree

14 files changed

+1411
-5
lines changed

14 files changed

+1411
-5
lines changed

src/sagemaker/estimator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,6 +2331,22 @@ def training_image_uri(self):
23312331
distribution = self.distribution # pylint: disable=no-member
23322332
else:
23332333
distribution = None
2334+
2335+
if hasattr(self, "tensorflow_version") or hasattr(self, "pytorch_version"):
2336+
processor = image_uris._processor(self.instance_type, ["cpu", "gpu"])
2337+
container_version = "cu110-ubuntu18.04" if processor == "gpu" else None
2338+
if self.tensorflow_version is not None: # pylint: disable=no-member
2339+
base_framework_version = (
2340+
f"tensorflow{self.tensorflow_version}" # pylint: disable=no-member
2341+
)
2342+
else:
2343+
base_framework_version = (
2344+
f"pytorch{self.pytorch_version}" # pylint: disable=no-member
2345+
)
2346+
else:
2347+
container_version = None
2348+
base_framework_version = None
2349+
23342350
return image_uris.retrieve(
23352351
self._framework_name,
23362352
self.sagemaker_session.boto_region_name,
@@ -2339,6 +2355,8 @@ def training_image_uri(self):
23392355
py_version=self.py_version, # pylint: disable=no-member
23402356
image_scope="training",
23412357
distribution=distribution,
2358+
base_framework_version=base_framework_version,
2359+
container_version=container_version,
23422360
)
23432361

23442362
@classmethod

src/sagemaker/fw_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,10 @@ def framework_name_from_image(image_uri):
288288
# We must support both the legacy and current image name format.
289289
name_pattern = re.compile(
290290
r"""^(?:sagemaker(?:-rl)?-)?
291-
(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost)(?:-)?
291+
(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost
292+
|huggingface-tensorflow|huggingface-pytorch)(?:-)?
292293
(scriptmode|training)?
293-
:(.*)-(.*?)-(py2|py3[67]?)$""",
294+
:(.*)-(.*?)-(py2|py3[67]?)(?:.*)$""",
294295
re.VERBOSE,
295296
)
296297
name_match = name_pattern.match(sagemaker_match.group(9))

src/sagemaker/huggingface/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.huggingface.estimator import HuggingFace # noqa: F401
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
import logging
17+
import re
18+
19+
from sagemaker.deprecations import renamed_kwargs
20+
from sagemaker.estimator import Framework
21+
from sagemaker.fw_utils import (
22+
framework_name_from_image,
23+
warn_if_parameter_server_with_multi_gpu,
24+
validate_smdistributed,
25+
)
26+
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
27+
28+
logger = logging.getLogger("sagemaker")
29+
30+
31+
class HuggingFace(Framework):
32+
"""Handle training of custom HuggingFace code."""
33+
34+
_framework_name = "huggingface"
35+
36+
def __init__(
37+
self,
38+
py_version,
39+
entry_point,
40+
transformers_version=None,
41+
tensorflow_version=None,
42+
pytorch_version=None,
43+
source_dir=None,
44+
hyperparameters=None,
45+
image_uri=None,
46+
distribution=None,
47+
**kwargs
48+
):
49+
"""This ``Estimator`` executes a HuggingFace script in a managed execution environment.
50+
51+
The managed HuggingFace environment is an Amazon-built Docker container that executes
52+
functions defined in the supplied ``entry_point`` Python script within a SageMaker
53+
Training Job.
54+
55+
Training is started by calling
56+
:meth:`~sagemaker.amazon.estimator.Framework.fit` on this Estimator.
57+
58+
Args:
59+
py_version (str): Python version you want to use for executing your model training
60+
code. Defaults to ``None``. Required unless ``image_uri`` is provided. List
61+
of supported versions:
62+
https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators
63+
entry_point (str): Path (absolute or relative) to the Python source
64+
file which should be executed as the entry point to training.
65+
If ``source_dir`` is specified, then ``entry_point``
66+
must point to a file located at the root of ``source_dir``.
67+
transformers_version (str): Transformers version you want to use for
68+
executing your model training code. Defaults to ``None``. Required unless
69+
``image_uri`` is provided. List of supported versions:
70+
https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
71+
tensorflow_version (str): TensorFlow version you want to use for
72+
executing your model training code. Defaults to ``None``. Required unless
73+
``pytorch_version`` is provided. List of supported versions:
74+
https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
75+
pytorch_version (str): PyTorch version you want to use for
76+
executing your model training code. Defaults to ``None``. Required unless
77+
``tensorflow_version`` is provided. List of supported versions:
78+
https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
79+
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
80+
with any other training source code dependencies aside from the entry
81+
point file (default: None). If ``source_dir`` is an S3 URI, it must
82+
point to a tar.gz file. Structure within this directory are preserved
83+
when training on Amazon SageMaker.
84+
hyperparameters (dict): Hyperparameters that will be used for
85+
training (default: None). The hyperparameters are made
86+
accessible as a dict[str, str] to the training code on
87+
SageMaker. For convenience, this accepts other types for keys
88+
and values, but ``str()`` will be called to convert them before
89+
training.
90+
image_uri (str): If specified, the estimator will use this image
91+
for training and hosting, instead of selecting the appropriate
92+
SageMaker official image based on framework_version and
93+
py_version. It can be an ECR url or dockerhub image and tag.
94+
Examples:
95+
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
96+
* ``custom-image:latest``
97+
98+
If ``framework_version`` or ``py_version`` are ``None``, then
99+
``image_uri`` is required. If also ``None``, then a ``ValueError``
100+
will be raised.
101+
distribution (dict): A dictionary with information on how to run distributed training
102+
(default: None). Currently, the following are supported:
103+
distributed training with parameter servers, SageMaker Distributed (SMD) Data
104+
and Model Parallelism, and MPI. SMD Model Parallelism can only be used with MPI.
105+
To enable parameter server use the following setup:
106+
107+
.. code:: python
108+
109+
{
110+
"parameter_server": {
111+
"enabled": True
112+
}
113+
}
114+
115+
To enable MPI:
116+
117+
.. code:: python
118+
119+
{
120+
"mpi": {
121+
"enabled": True
122+
}
123+
}
124+
125+
To enable SMDistributed Data Parallel or Model Parallel:
126+
127+
.. code:: python
128+
129+
{
130+
"smdistributed": {
131+
"dataparallel": {
132+
"enabled": True
133+
},
134+
"modelparallel": {
135+
"enabled": True,
136+
"parameters": {}
137+
}
138+
}
139+
}
140+
141+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
142+
constructor.
143+
144+
.. tip::
145+
146+
You can find additional parameters for initializing this class at
147+
:class:`~sagemaker.estimator.Framework` and
148+
:class:`~sagemaker.estimator.EstimatorBase`.
149+
"""
150+
self.framework_version = transformers_version
151+
self.py_version = py_version
152+
self.tensorflow_version = tensorflow_version
153+
self.pytorch_version = pytorch_version
154+
155+
self._validate_args(image_uri=image_uri)
156+
157+
if distribution is not None:
158+
instance_type = renamed_kwargs(
159+
"train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
160+
)
161+
162+
base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
163+
base_framework_version = (
164+
tensorflow_version if tensorflow_version is not None else pytorch_version
165+
)
166+
167+
validate_smdistributed(
168+
instance_type=instance_type,
169+
framework_name=base_framework_name,
170+
framework_version=base_framework_version,
171+
py_version=self.py_version,
172+
distribution=distribution,
173+
image_uri=image_uri,
174+
)
175+
176+
warn_if_parameter_server_with_multi_gpu(
177+
training_instance_type=instance_type, distribution=distribution
178+
)
179+
180+
if "enable_sagemaker_metrics" not in kwargs:
181+
kwargs["enable_sagemaker_metrics"] = True
182+
183+
super(HuggingFace, self).__init__(
184+
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
185+
)
186+
self.distribution = distribution or {}
187+
188+
def _validate_args(self, image_uri):
189+
"""Placeholder docstring"""
190+
if image_uri is not None:
191+
return
192+
if self.framework_version is None and image_uri is None:
193+
raise ValueError(
194+
"transformers_version, and image_uri are both None. "
195+
"Specify either transformers_version or image_uri"
196+
)
197+
if self.tensorflow_version is not None and self.pytorch_version is not None:
198+
raise ValueError(
199+
"tensorflow_version and pytorch_version are both not None. "
200+
"Specify only tensorflow_version or pytorch_version."
201+
)
202+
if self.tensorflow_version is None and self.pytorch_version is None:
203+
raise ValueError(
204+
"tensorflow_version and pytorch_version are both None. "
205+
"Specify either tensorflow_version or pytorch_version."
206+
)
207+
208+
def hyperparameters(self):
209+
"""Return hyperparameters used by your custom PyTorch code during model training."""
210+
hyperparameters = super(HuggingFace, self).hyperparameters()
211+
additional_hyperparameters = self._distribution_configuration(
212+
distribution=self.distribution
213+
)
214+
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
215+
return hyperparameters
216+
217+
def create_model(
218+
self,
219+
model_server_workers=None,
220+
role=None,
221+
vpc_config_override=VPC_CONFIG_DEFAULT,
222+
entry_point=None,
223+
source_dir=None,
224+
dependencies=None,
225+
**kwargs
226+
):
227+
"""Placeholder docstring"""
228+
raise NotImplementedError("Creating model with HuggingFace training job is not supported.")
229+
230+
@classmethod
231+
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
232+
"""Convert the job description to init params that can be handled by the class constructor.
233+
234+
Args:
235+
job_details: The returned job details from a describe_training_job
236+
API call.
237+
model_channel_name (str): Name of the channel where pre-trained
238+
model data will be downloaded.
239+
240+
Returns:
241+
dictionary: The transformed init_params
242+
"""
243+
init_params = super(HuggingFace, cls)._prepare_init_params_from_job_description(
244+
job_details, model_channel_name
245+
)
246+
image_uri = init_params.pop("image_uri")
247+
framework, py_version, tag, _ = framework_name_from_image(image_uri)
248+
249+
if tag is None:
250+
framework_version = None
251+
else:
252+
framework, pt_or_tf = framework.split("-")
253+
tag_pattern = re.compile("^(.*)-transformers(.*)-(cpu|gpu)-(py2|py3[67]?)$")
254+
tag_match = tag_pattern.match(tag)
255+
pt_or_tf_version = tag_match.group(1)
256+
framework_version = tag_match.group(2)
257+
if pt_or_tf == "pytorch":
258+
init_params["pytorch_version"] = pt_or_tf_version
259+
else:
260+
init_params["tensorflow_version"] = pt_or_tf_version
261+
262+
init_params["transformers_version"] = framework_version
263+
init_params["py_version"] = py_version
264+
265+
if not framework:
266+
# If we were unable to parse the framework name from the image it is not one of our
267+
# officially supported images, in this case just add the image to the init params.
268+
init_params["image_uri"] = image_uri
269+
return init_params
270+
271+
if framework != cls._framework_name:
272+
raise ValueError(
273+
"Training job: {} didn't use image for requested framework".format(
274+
job_details["TrainingJobName"]
275+
)
276+
)
277+
278+
return init_params

0 commit comments

Comments
 (0)