Skip to content

Commit 071303c

Browse files
edit frameversion comment and specifier set function
1 parent c1955db commit 071303c

File tree

2 files changed

+7
-17
lines changed

2 files changed

+7
-17
lines changed

src/sagemaker/huggingface/training_compiler/config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ def validate(cls, estimator):
107107
Version(estimator.pytorch_version) in SpecifierSet("> 1.11")
108108
):
109109
error_helper_string = (
110-
"Training Compiler is only valid between HuggingFace PyTorch 1.9-1.11 "
111-
"for SageMaker Training Compiler."
110+
"SageMaker Training Compiler is only supported with HuggingFace PyTorch 1.9-1.11 "
112111
" Received pytorch_version={} which is unsupported."
113112
)
114113
raise ValueError(error_helper_string.format(estimator.pytorch_version))

src/sagemaker/tensorflow/training_compiler/config.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,17 @@ def validate(cls, estimator):
9292
super(TrainingCompilerConfig, cls).validate(estimator)
9393

9494
if estimator.framework_version:
95-
if Version(estimator.framework_version) in SpecifierSet(
96-
f"< {cls.MIN_SUPPORTED_VERSION}"
95+
if not Version(estimator.framework_version) in SpecifierSet(
96+
f">= {cls.MIN_SUPPORTED_VERSION}", f"<= {cls.MAX_SUPPORTED_VERSION}"
9797
):
9898
error_helper_string = (
9999
"SageMaker Training Compiler only supports TensorFlow version "
100-
">= {} but received {}"
100+
"between {} to {} but received {}"
101101
)
102102
error_helper_string = error_helper_string.format(
103-
cls.MIN_SUPPORTED_VERSION, estimator.framework_version
104-
)
105-
raise ValueError(error_helper_string)
106-
if Version(estimator.framework_version) in SpecifierSet(
107-
f"> {cls.MAX_SUPPORTED_VERSION}"
108-
):
109-
error_helper_string = (
110-
"SageMaker Training Compiler only supports TensorFlow version "
111-
"<= {} but received {}"
112-
)
113-
error_helper_string = error_helper_string.format(
114-
cls.MAX_SUPPORTED_VERSION, estimator.framework_version
103+
cls.MIN_SUPPORTED_VERSION,
104+
cls.MAX_SUPPORTED_VERSION,
105+
estimator.framework_version,
115106
)
116107
raise ValueError(error_helper_string)
117108

0 commit comments

Comments
 (0)