Skip to content

Commit 1c68699

Browse files
SDK changes for TRCOMP support
1 parent 479610d commit 1c68699

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

src/sagemaker/huggingface/training_compiler/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,16 @@ def validate(cls, estimator):
102102

103103
super(TrainingCompilerConfig, cls).validate(estimator)
104104

105+
if estimator.pytorch_version:
106+
if (Version(estimator.pytorch_version) in SpecifierSet("< 1.9")) or
107+
(Version(estimator.pytorch_version) in SpecifierSet("> 1.11")):
108+
error_helper_string = (
109+
"Training Compiler is only valid between HuggingFace PyTorch 1.9-1.11 "
110+
"for SageMaker Training Compiler."
111+
" Received pytorch_version={} which is unsupported."
112+
)
113+
raise ValueError(error_helper_string.format(estimator.pytorch_version))
114+
105115
if estimator.image_uri:
106116
error_helper_string = (
107117
"Overriding the image URI is currently not supported "

src/sagemaker/tensorflow/training_compiler/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class TrainingCompilerConfig(BaseConfig):
2626

2727
SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
2828
MIN_SUPPORTED_VERSION = "2.9"
29+
MAX_SUPPORTED_VERSION = "2.11"
2930

3031
def __init__(self, enabled=True, debug=False):
3132
"""This class initializes a ``TrainingCompilerConfig`` instance.
@@ -102,6 +103,17 @@ def validate(cls, estimator):
102103
cls.MIN_SUPPORTED_VERSION, estimator.framework_version
103104
)
104105
raise ValueError(error_helper_string)
106+
elif Version(estimator.framework_version) in SpecifierSet(
107+
f"> {cls.MAX_SUPPORTED_VERSION}"
108+
):
109+
error_helper_string = (
110+
"SageMaker Training Compiler only supports TensorFlow version "
111+
"<= {} but received {}"
112+
)
113+
error_helper_string = error_helper_string.format(
114+
cls.MAX_SUPPORTED_VERSION, estimator.framework_version
115+
)
116+
raise ValueError(error_helper_string)
105117

106118
if estimator.distribution and "multi_worker_mirrored_strategy" in estimator.distribution:
107119
mwms_enabled = estimator.distribution.get("multi_worker_mirrored_strategy").get(

0 commit comments

Comments
 (0)