@@ -92,26 +92,17 @@ def validate(cls, estimator):
92
92
super (TrainingCompilerConfig , cls ).validate (estimator )
93
93
94
94
if estimator .framework_version :
95
- if Version (estimator .framework_version ) in SpecifierSet (
96
- f"< { cls .MIN_SUPPORTED_VERSION } "
95
+ if not Version (estimator .framework_version ) in SpecifierSet (
96
+ f">= { cls .MIN_SUPPORTED_VERSION } " , f"<= { cls . MAX_SUPPORTED_VERSION } "
97
97
):
98
98
error_helper_string = (
99
99
"SageMaker Training Compiler only supports TensorFlow version "
100
- ">= {} but received {}"
100
+ "between {} to {} but received {}"
101
101
)
102
102
error_helper_string = error_helper_string .format (
103
- cls .MIN_SUPPORTED_VERSION , estimator .framework_version
104
- )
105
- raise ValueError (error_helper_string )
106
- if Version (estimator .framework_version ) in SpecifierSet (
107
- f"> { cls .MAX_SUPPORTED_VERSION } "
108
- ):
109
- error_helper_string = (
110
- "SageMaker Training Compiler only supports TensorFlow version "
111
- "<= {} but received {}"
112
- )
113
- error_helper_string = error_helper_string .format (
114
- cls .MAX_SUPPORTED_VERSION , estimator .framework_version
103
+ cls .MIN_SUPPORTED_VERSION ,
104
+ cls .MAX_SUPPORTED_VERSION ,
105
+ estimator .framework_version ,
115
106
)
116
107
raise ValueError (error_helper_string )
117
108
0 commit comments