Skip to content

Commit b30e771

Browse files
committed
feature: Adding Training Compiler support for TensorFlow estimator
1 parent 576bc29 commit b30e771

File tree

5 files changed

+127
-9
lines changed

5 files changed

+127
-9
lines changed

src/sagemaker/image_uris.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,21 +134,18 @@ def retrieve(
134134
tolerate_vulnerable_model,
135135
tolerate_deprecated_model,
136136
)
137-
if training_compiler_config is None:
137+
138+
if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
139+
config = _config_for_framework_and_scope(
140+
framework + "-training-compiler", image_scope, accelerator_type
141+
)
142+
else:
138143
_framework = framework
139144
if framework == HUGGING_FACE_FRAMEWORK:
140145
inference_tool = _get_inference_tool(inference_tool, instance_type)
141146
if inference_tool == "neuron":
142147
_framework = f"{framework}-{inference_tool}"
143148
config = _config_for_framework_and_scope(_framework, image_scope, accelerator_type)
144-
elif framework == HUGGING_FACE_FRAMEWORK:
145-
config = _config_for_framework_and_scope(
146-
framework + "-training-compiler", image_scope, accelerator_type
147-
)
148-
else:
149-
raise ValueError(
150-
"Unsupported Configuration: Training Compiler is only supported with HuggingFace"
151-
)
152149

153150
original_version = version
154151
version = _validate_version_and_set_if_needed(version, config, framework)

src/sagemaker/tensorflow/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@
1616
from sagemaker.tensorflow.estimator import TensorFlow # noqa: F401 (imported but unused)
1717
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor # noqa: F401
1818
from sagemaker.tensorflow.processing import TensorFlowProcessor # noqa: F401
19+
20+
from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig # noqa: F401

src/sagemaker/tensorflow/estimator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sagemaker.transformer import Transformer
2727
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2828
from sagemaker.workflow import is_pipeline_variable
29+
from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig
2930

3031
logger = logging.getLogger("sagemaker")
3132

@@ -45,6 +46,7 @@ def __init__(
4546
model_dir=None,
4647
image_uri=None,
4748
distribution=None,
49+
compiler_config=None,
4850
**kwargs
4951
):
5052
"""Initialize a ``TensorFlow`` estimator.
@@ -157,6 +159,8 @@ def __init__(
157159
158160
To learn more, see `Training with parameter servers
159161
<https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#training-with-parameter-servers>`_.
162+
compiler_config (:class:`~sagemaker.tensorflow.TrainingCompilerConfig`):
163+
Configures SageMaker Training Compiler to accelerate training.
160164
161165
**kwargs: Additional kwargs passed to the Framework constructor.
162166
@@ -202,6 +206,16 @@ def __init__(
202206
self.distribution = distribution or {}
203207

204208
self._validate_args(py_version=py_version)
209+
if compiler_config is not None:
210+
if not isinstance(compiler_config, TrainingCompilerConfig):
211+
error_string = (
212+
f"Expected instance of type {TrainingCompilerConfig}"
213+
f"for argument compiler_config. "
214+
f"Instead got {type(compiler_config)}"
215+
)
216+
raise ValueError(error_string)
217+
if compiler_config:
218+
compiler_config.validate(self)
205219

206220
def _validate_args(self, py_version):
207221
"""Placeholder docstring"""

src/sagemaker/tensorflow/training_compiler/__init__.py

Whitespace-only changes.
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Configuration for the SageMaker Training Compiler."""
14+
from __future__ import absolute_import
15+
import logging
16+
from packaging import version
17+
18+
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class TrainingCompilerConfig(BaseConfig):
24+
"""The SageMaker Training Compiler configuration class."""
25+
26+
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4", "g5"]
27+
MIN_SUPPORTED_VERSION = version.parse("2.9")
28+
29+
def __init__(
30+
self,
31+
enabled=True,
32+
debug=False,
33+
):
34+
"""This class initializes a ``TrainingCompilerConfig`` instance.
35+
36+
`Amazon SageMaker Training Compiler
37+
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_
38+
is a feature of SageMaker Training
39+
and speeds up training jobs by optimizing model execution graphs.
40+
41+
You can compile TensorFlow models
42+
by passing the object of this configuration class to the ``compiler_config``
43+
parameter of the :class:`~sagemaker.tensorflow.TensorFlow`
44+
estimator.
45+
46+
Args:
47+
enabled (bool): Optional. Switch to enable SageMaker Training Compiler.
48+
The default is ``True``.
49+
debug (bool): Optional. Whether to dump detailed logs for debugging.
50+
This comes with a potential performance slowdown.
51+
The default is ``False``.
52+
53+
**Example**: The following code shows the basic usage of the
54+
:class:`sagemaker.tensorflow.TrainingCompilerConfig()` class
55+
to run a TensorFlow training job with the compiler.
56+
57+
.. code-block:: python
58+
59+
from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig
60+
61+
tensorflow_estimator=TensorFlow(
62+
...
63+
compiler_config=TrainingCompilerConfig()
64+
)
65+
66+
.. seealso::
67+
68+
For more information about how to enable SageMaker Training Compiler
69+
for various training settings such as using TensorFlow-based models,
70+
PyTorch-based models, and distributed training,
71+
see `Enable SageMaker Training Compiler
72+
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_
73+
in the `Amazon SageMaker Training Compiler developer guide
74+
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_.
75+
76+
"""
77+
78+
super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)
79+
80+
@classmethod
81+
def validate(
82+
cls,
83+
estimator,
84+
):
85+
"""Checks if SageMaker Training Compiler is configured correctly.
86+
87+
Args:
88+
estimator (:class:`sagemaker.tensorflow.estimator.TensorFlow`): An estimator object.
89+
If SageMaker Training Compiler is enabled, it will validate whether
90+
the estimator is configured to be compatible with Training Compiler.
91+
92+
Raises:
93+
ValueError: Raised if the requested configuration is not compatible
94+
with SageMaker Training Compiler.
95+
"""
96+
97+
super(TrainingCompilerConfig, cls).validate(estimator)
98+
99+
if estimator.framework_version:
100+
if version.parse(estimator.framework_version) < cls.MIN_SUPPORTED_VERSION:
101+
error_helper_string = (
102+
f"SageMaker Training Compiler only supports TensorFlow version "
103+
f">= {cls.MIN_SUPPORTED_VERSION} but received {estimator.framework_version}"
104+
)
105+
raise ValueError(error_helper_string)

0 commit comments

Comments
 (0)