Skip to content

Commit 576bc29

Browse files
committed
change: Restructuring Training Compiler UI implementation
1 parent 764f61b commit 576bc29

File tree

5 files changed

+128
-40
lines changed

5 files changed

+128
-40
lines changed

src/sagemaker/huggingface/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401
1818
from sagemaker.huggingface.processing import HuggingFaceProcessor # noqa:F401
1919

20-
from sagemaker.training_compiler.config import TrainingCompilerConfig # noqa: F401
20+
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig # noqa: F401

src/sagemaker/huggingface/estimator.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker.huggingface.model import HuggingFaceModel
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828

29-
from sagemaker.training_compiler.config import TrainingCompilerConfig
29+
from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig
3030

3131
logger = logging.getLogger("sagemaker")
3232

@@ -199,11 +199,7 @@ def __init__(
199199
)
200200
raise ValueError(error_string)
201201
if compiler_config:
202-
compiler_config.validate(
203-
image_uri=image_uri,
204-
instance_type=instance_type,
205-
distribution=distribution,
206-
)
202+
compiler_config.validate(self)
207203

208204
self.distribution = distribution or {}
209205
self.compiler_config = compiler_config

src/sagemaker/huggingface/training_compiler/__init__.py

Whitespace-only changes.
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Configuration for the SageMaker Training Compiler."""
14+
from __future__ import absolute_import
15+
import logging
16+
17+
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class TrainingCompilerConfig(BaseConfig):
    """The SageMaker Training Compiler configuration class for Hugging Face."""

    # Instance class prefixes (the middle part of ``ml.<class>.<size>``)
    # declared as supported for this configuration.
    SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"]

    def __init__(
        self,
        enabled=True,
        debug=False,
    ):
        """This class initializes a ``TrainingCompilerConfig`` instance.

        `Amazon SageMaker Training Compiler
        <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
        is a feature of SageMaker Training
        and speeds up training jobs by optimizing model execution graphs.

        You can compile Hugging Face models
        by passing the object of this configuration class to the ``compiler_config``
        parameter of the :class:`~sagemaker.huggingface.HuggingFace`
        estimator.

        Args:
            enabled (bool): Optional. Switch to enable SageMaker Training Compiler.
                The default is ``True``.
            debug (bool): Optional. Whether to dump detailed logs for debugging.
                This comes with a potential performance slowdown.
                The default is ``False``.

        **Example**: The following code shows the basic usage of the
        :class:`sagemaker.huggingface.TrainingCompilerConfig()` class
        to run a HuggingFace training job with the compiler.

        .. code-block:: python

            from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig

            huggingface_estimator=HuggingFace(
                ...
                compiler_config=TrainingCompilerConfig()
            )

        .. seealso::

            For more information about how to enable SageMaker Training Compiler
            for various training settings such as using TensorFlow-based models,
            PyTorch-based models, and distributed training,
            see `Enable SageMaker Training Compiler
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
            in the `Amazon SageMaker Training Compiler developer guide
            <https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.

        """

        super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)

    @classmethod
    def validate(
        cls,
        estimator,
    ):
        """Checks if SageMaker Training Compiler is configured correctly.

        The framework-agnostic checks are delegated to the base class first;
        this override then adds the Hugging Face specific rule that a custom
        image URI cannot be combined with SageMaker Training Compiler.

        Args:
            estimator: An estimator object (the
                :class:`~sagemaker.huggingface.HuggingFace` estimator passes
                itself). Its ``image_uri`` attribute is inspected here, in
                addition to whatever the base class validation reads.

        Raises:
            ValueError: Raised if the requested configuration is not compatible
                with SageMaker Training Compiler, including when
                ``estimator.image_uri`` is set.
        """

        super(TrainingCompilerConfig, cls).validate(estimator)

        if estimator.image_uri:
            # Adjacent string literals are concatenated verbatim, so every
            # fragment that is followed by more text must end with a space.
            error_helper_string = (
                "Overriding the image URI is currently not supported "
                "for SageMaker Training Compiler. "
                "Specify the following parameters to run the Hugging Face training job "
                "with SageMaker Training Compiler enabled: "
                "transformer_version, tensorflow_version or pytorch_version, and compiler_config."
            )
            raise ValueError(error_helper_string)

src/sagemaker/training_compiler/config.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -118,36 +118,25 @@ def _to_hyperparameter_dict(self):
118118
@classmethod
119119
def validate(
120120
cls,
121-
image_uri,
122-
instance_type,
123-
distribution,
121+
estimator,
124122
):
125123
"""Checks if SageMaker Training Compiler is configured correctly.
126124
127125
Args:
128-
image_uri (str): A string of a Docker image URI that's specified
129-
to :class:`~sagemaker.huggingface.HuggingFace`.
130-
If SageMaker Training Compiler is enabled, the HuggingFace estimator
131-
automatically chooses the right image URI. You cannot specify and override
132-
the image URI.
133-
instance_type (str): A string of the training instance type that's specified
134-
to :class:`~sagemaker.huggingface.HuggingFace`.
135-
The `validate` classmethod raises error
136-
if an instance type not in the ``SUPPORTED_INSTANCE_CLASS_PREFIXES`` list
137-
or ``local`` is passed to the `instance_type` parameter.
138-
distribution (dict): A dictionary of the distributed training option that's specified
139-
to :class:`~sagemaker.huggingface.HuggingFace`.
140-
SageMaker's distributed data parallel and model parallel libraries
141-
are currently not compatible
142-
with SageMaker Training Compiler.
126+
estimator (object): An estimator object.
127+
If SageMaker Training Compiler is enabled, it will validate whether
128+
the estimator is configured to be compatible with Training Compiler.
129+
143130
144131
Raises:
145132
ValueError: Raised if the requested configuration is not compatible
146133
with SageMaker Training Compiler.
147134
"""
148135

149-
if "local" not in instance_type:
150-
requested_instance_class = instance_type.split(".")[1] # Expecting ml.class.size
136+
if "local" not in estimator.instance_type:
137+
requested_instance_class = estimator.instance_type.split(".")[
138+
1
139+
] # Expecting ml.class.size
151140
if not any(
152141
[
153142
requested_instance_class.startswith(i)
@@ -161,25 +150,23 @@ def validate(
161150
requested_instance_class, cls.SUPPORTED_INSTANCE_CLASS_PREFIXES
162151
)
163152
raise ValueError(error_helper_string)
164-
elif instance_type == "local":
153+
elif estimator.instance_type == "local":
165154
error_helper_string = (
166155
"The local mode is not supported by SageMaker Training Compiler."
167-
"It only supports the following GPU instances: p3, g4dn, and p4."
168-
)
169-
raise ValueError(error_helper_string)
170-
171-
if image_uri:
172-
error_helper_string = (
173-
"Overriding the image URI is currently not supported "
174-
"for SageMaker Training Compiler."
175-
"Specify the following parameters to run the Hugging Face training job "
176-
"with SageMaker Training Compiler enabled: "
177-
"transformer_version, tensorflow_version or pytorch_version, and compiler_config."
156+
f"It only supports the following GPU instances: {cls.SUPPORTED_INSTANCE_CLASS_PREFIXES}."
178157
)
179158
raise ValueError(error_helper_string)
180159

181-
if distribution and "smdistributed" in distribution:
160+
if estimator.distribution and "smdistributed" in estimator.distribution:
182161
raise ValueError(
183162
"SageMaker distributed training configuration is currently not compatible with "
184163
"SageMaker Training Compiler."
185164
)
165+
166+
if estimator.debugger_hook_config or (not estimator.disable_profiler):
167+
logger.warning(
168+
f"Using Debugger and/or Profiler with SageMaker Training Compiler causes poor "
169+
f"performance. Found debugger_hook_config={estimator.debugger_hook_config} "
170+
f"disable_profiler={estimator.disable_profiler}. Please set "
171+
f"debugger_hook_config=None and disable_profiler=True for optimal performance."
172+
)

0 commit comments

Comments
 (0)