-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: Adding Training Compiler support for TensorFlow estimator starting TF 2.9 #3156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
b4adec2
change: Restructuring Training Compiler UI implementation
Lokiiiiii d818176
feature: Adding Training Compiler support for TensorFlow estimator
Lokiiiiii 5f11033
change: Tests targetting Training Compiler in TensorFlow estimator
Lokiiiiii 846bb82
fix: linting in training compiler files
Lokiiiiii 5a9db84
fix: liniting in tensorflow estimator
Lokiiiiii 2912b47
fix: syntax error in trcomp tests
Lokiiiiii 6b8007e
fix: logic error in trcomp initialization in HF estimator
Lokiiiiii 85af72d
fix: logic error in trcomp test for TF estimator
Lokiiiiii 79efc4f
fix: logic error in version comparison
Lokiiiiii 239d639
fix: syntax error in TF trcomp
Lokiiiiii 2b8de05
change: documentation updates for trcomp
Lokiiiiii 9ea141c
Apply documentation suggestions from code review for trcomp
Lokiiiiii 26729be
update: documentation update for trcomp
Lokiiiiii 02697a0
Adding tests for the TF trcomp BYOC path
Lokiiiiii 84a6b00
linting trcomp config
Lokiiiiii 9f4cf52
Adding logic to convert compiler_config to hyperparameters
Lokiiiiii a4d86c1
Fixing trcomp tensorflow tests
Lokiiiiii 80a223a
Fixing logic error in training_compiler supported version
Lokiiiiii df4d6d9
Fixing logic error in training_compiler
Lokiiiiii 56ca880
Fixing logic error in training_compiler
Lokiiiiii File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Configuration for the SageMaker Training Compiler.""" | ||
from __future__ import absolute_import | ||
import logging | ||
|
||
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TrainingCompilerConfig(BaseConfig): | ||
"""The SageMaker Training Compiler configuration class.""" | ||
|
||
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4"] | ||
|
||
def __init__( | ||
self, | ||
enabled=True, | ||
debug=False, | ||
): | ||
"""This class initializes a ``TrainingCompilerConfig`` instance. | ||
|
||
`Amazon SageMaker Training Compiler | ||
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_ | ||
is a feature of SageMaker Training | ||
and speeds up training jobs by optimizing model execution graphs. | ||
|
||
You can compile Hugging Face models | ||
by passing the object of this configuration class to the ``compiler_config`` | ||
parameter of the :class:`~sagemaker.huggingface.HuggingFace` | ||
estimator. | ||
|
||
Args: | ||
enabled (bool): Optional. Switch to enable SageMaker Training Compiler. | ||
The default is ``True``. | ||
debug (bool): Optional. Whether to dump detailed logs for debugging. | ||
This comes with a potential performance slowdown. | ||
The default is ``False``. | ||
|
||
**Example**: The following code shows the basic usage of the | ||
:class:`sagemaker.huggingface.TrainingCompilerConfig()` class | ||
to run a HuggingFace training job with the compiler. | ||
|
||
.. code-block:: python | ||
|
||
from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig | ||
|
||
huggingface_estimator=HuggingFace( | ||
... | ||
compiler_config=TrainingCompilerConfig() | ||
) | ||
|
||
.. seealso:: | ||
|
||
For more information about how to enable SageMaker Training Compiler | ||
for various training settings such as using TensorFlow-based models, | ||
PyTorch-based models, and distributed training, | ||
see `Enable SageMaker Training Compiler | ||
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_ | ||
in the `Amazon SageMaker Training Compiler developer guide | ||
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_. | ||
|
||
""" | ||
|
||
super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug) | ||
|
||
@classmethod | ||
def validate( | ||
cls, | ||
estimator, | ||
): | ||
"""Checks if SageMaker Training Compiler is configured correctly. | ||
|
||
Args: | ||
estimator (str): A estimator object | ||
If SageMaker Training Compiler is enabled, it will validate whether | ||
the estimator is configured to be compatible with Training Compiler. | ||
|
||
Raises: | ||
ValueError: Raised if the requested configuration is not compatible | ||
with SageMaker Training Compiler. | ||
""" | ||
|
||
super(TrainingCompilerConfig, cls).validate(estimator) | ||
|
||
if estimator.image_uri: | ||
error_helper_string = ( | ||
"Overriding the image URI is currently not supported " | ||
"for SageMaker Training Compiler." | ||
"Specify the following parameters to run the Hugging Face training job " | ||
"with SageMaker Training Compiler enabled: " | ||
"transformer_version, tensorflow_version or pytorch_version, and compiler_config." | ||
) | ||
raise ValueError(error_helper_string) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Configuration for the SageMaker Training Compiler.""" | ||
from __future__ import absolute_import | ||
import logging | ||
from packaging.specifiers import SpecifierSet | ||
from packaging.version import Version | ||
|
||
from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TrainingCompilerConfig(BaseConfig): | ||
"""The SageMaker Training Compiler configuration class.""" | ||
|
||
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "g4dn", "p4", "g5"] | ||
MIN_SUPPORTED_VERSION = "2.9" | ||
|
||
def __init__( | ||
self, | ||
enabled=True, | ||
debug=False, | ||
): | ||
"""This class initializes a ``TrainingCompilerConfig`` instance. | ||
|
||
`Amazon SageMaker Training Compiler | ||
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_ | ||
is a feature of SageMaker Training | ||
and speeds up training jobs by optimizing model execution graphs. | ||
|
||
You can compile TensorFlow models | ||
by passing the object of this configuration class to the ``compiler_config`` | ||
parameter of the :class:`~sagemaker.tensorflow.TensorFlow` | ||
estimator. | ||
|
||
Args: | ||
enabled (bool): Optional. Switch to enable SageMaker Training Compiler. | ||
The default is ``True``. | ||
debug (bool): Optional. Whether to dump detailed logs for debugging. | ||
This comes with a potential performance slowdown. | ||
The default is ``False``. | ||
|
||
**Example**: The following code shows the basic usage of the | ||
:class:`sagemaker.tensorflow.TrainingCompilerConfig()` class | ||
to run a TensorFlow training job with the compiler. | ||
|
||
.. code-block:: python | ||
|
||
from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig | ||
|
||
tensorflow_estimator=TensorFlow( | ||
... | ||
compiler_config=TrainingCompilerConfig() | ||
) | ||
|
||
.. seealso:: | ||
|
||
For more information about how to enable SageMaker Training Compiler | ||
for various training settings such as using TensorFlow-based models, | ||
PyTorch-based models, and distributed training, | ||
see `Enable SageMaker Training Compiler | ||
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html>`_ | ||
in the `Amazon SageMaker Training Compiler developer guide | ||
<https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler.html>`_. | ||
|
||
""" | ||
|
||
super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug) | ||
|
||
@classmethod | ||
def validate( | ||
cls, | ||
estimator, | ||
): | ||
"""Checks if SageMaker Training Compiler is configured correctly. | ||
|
||
Args: | ||
estimator (str): A estimator object | ||
If SageMaker Training Compiler is enabled, it will validate whether | ||
the estimator is configured to be compatible with Training Compiler. | ||
|
||
Raises: | ||
ValueError: Raised if the requested configuration is not compatible | ||
with SageMaker Training Compiler. | ||
""" | ||
|
||
super(TrainingCompilerConfig, cls).validate(estimator) | ||
|
||
if estimator.framework_version: | ||
if Version(estimator.framework_version) in SpecifierSet( | ||
f"< {cls.MIN_SUPPORTED_VERSION}" | ||
): | ||
error_helper_string = ( | ||
"SageMaker Training Compiler only supports TensorFlow version " | ||
">= {} but received {}" | ||
) | ||
error_helper_string = error_helper_string.format( | ||
cls.MIN_SUPPORTED_VERSION, estimator.framework_version | ||
) | ||
raise ValueError(error_helper_string) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.