Skip to content

feature: SDK changes for TRCOMP support #3714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 16, 2023
11 changes: 11 additions & 0 deletions src/sagemaker/huggingface/training_compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,17 @@ def validate(cls, estimator):

super(TrainingCompilerConfig, cls).validate(estimator)

if estimator.pytorch_version:
if (Version(estimator.pytorch_version) in SpecifierSet("< 1.9")) or (
Version(estimator.pytorch_version) in SpecifierSet("> 1.11")
):
error_helper_string = (
"Training Compiler is only valid between HuggingFace PyTorch 1.9-1.11 "
"for SageMaker Training Compiler."
" Received pytorch_version={} which is unsupported."
)
raise ValueError(error_helper_string.format(estimator.pytorch_version))

if estimator.image_uri:
error_helper_string = (
"Overriding the image URI is currently not supported "
Expand Down
12 changes: 12 additions & 0 deletions src/sagemaker/tensorflow/training_compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TrainingCompilerConfig(BaseConfig):

SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
MIN_SUPPORTED_VERSION = "2.9"
MAX_SUPPORTED_VERSION = "2.11"

def __init__(self, enabled=True, debug=False):
"""This class initializes a ``TrainingCompilerConfig`` instance.
Expand Down Expand Up @@ -102,6 +103,17 @@ def validate(cls, estimator):
cls.MIN_SUPPORTED_VERSION, estimator.framework_version
)
raise ValueError(error_helper_string)
if Version(estimator.framework_version) in SpecifierSet(
f"> {cls.MAX_SUPPORTED_VERSION}"
):
error_helper_string = (
"SageMaker Training Compiler only supports TensorFlow version "
"<= {} but received {}"
)
error_helper_string = error_helper_string.format(
cls.MAX_SUPPORTED_VERSION, estimator.framework_version
)
raise ValueError(error_helper_string)

if estimator.distribution and "multi_worker_mirrored_strategy" in estimator.distribution:
mwms_enabled = estimator.distribution.get("multi_worker_mirrored_strategy").get(
Expand Down