Skip to content

Commit 5b2daa4

Browse files
edit unit test for trcomp support version coverage
1 parent 32b9c79 commit 5b2daa4

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def test_unsupported_gpu_instance(
218218
).fit()
219219

220220

221-
def test_unsupported_framework_version(
221+
def test_unsupported_framework_version_min(
222222
huggingface_training_compiler_version,
223223
):
224224
with pytest.raises(ValueError):
@@ -229,9 +229,24 @@ def test_unsupported_framework_version(
229229
instance_count=INSTANCE_COUNT,
230230
instance_type=INSTANCE_TYPE,
231231
transformers_version=huggingface_training_compiler_version,
232-
pytorch_version=".".join(
233-
["99"] * len(huggingface_training_compiler_version.split("."))
234-
),
232+
pytorch_version="1.8",
233+
enable_sagemaker_metrics=False,
234+
compiler_config=TrainingCompilerConfig(),
235+
).fit()
236+
237+
238+
def test_unsupported_framework_version_max(
239+
huggingface_training_compiler_version,
240+
):
241+
with pytest.raises(ValueError):
242+
HuggingFace(
243+
py_version="py38",
244+
entry_point=SCRIPT_PATH,
245+
role=ROLE,
246+
instance_count=INSTANCE_COUNT,
247+
instance_type=INSTANCE_TYPE,
248+
transformers_version=huggingface_training_compiler_version,
249+
pytorch_version="1.12",
235250
enable_sagemaker_metrics=False,
236251
compiler_config=TrainingCompilerConfig(),
237252
).fit()

tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_gpu_instance(
189189
compiler_config=TrainingCompilerConfig(),
190190
).fit()
191191

192-
def test_framework_version(self, tensorflow_training_py_version):
192+
def test_framework_version_min(self, tensorflow_training_py_version):
193193
with pytest.raises(ValueError):
194194
TensorFlow(
195195
py_version=tensorflow_training_py_version,
@@ -202,6 +202,19 @@ def test_framework_version(self, tensorflow_training_py_version):
202202
compiler_config=TrainingCompilerConfig(),
203203
).fit()
204204

205+
def test_framework_version_max(self, tensorflow_training_py_version):
206+
with pytest.raises(ValueError):
207+
TensorFlow(
208+
py_version=tensorflow_training_py_version,
209+
entry_point=SCRIPT_PATH,
210+
role=ROLE,
211+
instance_count=INSTANCE_COUNT,
212+
instance_type=INSTANCE_TYPE,
213+
framework_version="2.12",
214+
enable_sagemaker_metrics=False,
215+
compiler_config=TrainingCompilerConfig(),
216+
).fit()
217+
205218
def test_mwms(self, tensorflow_training_version, tensorflow_training_py_version):
206219
with pytest.raises(ValueError):
207220
TensorFlow(

0 commit comments

Comments
 (0)