Skip to content

Commit aa246ad

Browse files
author
Rui Wang Napieralski
committed
fix: add version length mismatch validation for HuggingFace
1 parent b66cb98 commit aa246ad

File tree

3 files changed

+69
-1
lines changed

3 files changed

+69
-1
lines changed

src/sagemaker/huggingface/estimator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def _validate_args(self, image_uri):
189189
"""Placeholder docstring"""
190190
if image_uri is not None:
191191
return
192+
192193
if self.framework_version is None and image_uri is None:
193194
raise ValueError(
194195
"transformers_version, and image_uri are both None. "
@@ -204,6 +205,17 @@ def _validate_args(self, image_uri):
204205
"tensorflow_version and pytorch_version are both None. "
205206
"Specify either tensorflow_version or pytorch_version."
206207
)
208+
base_framework_version_len = (
209+
len(self.tensorflow_version.split("."))
210+
if self.tensorflow_version is not None
211+
else len(self.pytorch_version.split("."))
212+
)
213+
transformers_version_len = len(self.framework_version.split("."))
214+
if transformers_version_len != base_framework_version_len:
215+
raise ValueError(
216+
"Please use either full version or shortened version for both "
217+
"transformers_version, tensorflow_version and pytorch_version."
218+
)
207219

208220
def hyperparameters(self):
209221
"""Return hyperparameters used by your custom PyTorch code during model training."""

tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,10 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
193193
@pytest.fixture(scope="module")
def huggingface_pytorch_version(huggingface_training_version):
    """Return the PyTorch version string paired with a HuggingFace training version.

    Versions newer than 4.4.2 are skipped. For the remaining versions the
    returned PyTorch version uses the same number of dot-separated components
    as ``huggingface_training_version``: "1.6.0" for a three-part version,
    "1.6" otherwise.
    """
    # Guard clause: pytest.skip raises, so nothing below runs for newer versions.
    if not Version(huggingface_training_version) <= Version("4.4.2"):
        pytest.skip("Skipping Huggingface version.")

    component_count = len(huggingface_training_version.split("."))
    return "1.6.0" if component_count == 3 else "1.6"
199202

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,59 @@ def _create_train_job(version, base_framework_version):
168168
}
169169

170170

171+
def test_huggingface_invalid_args():
    """Check that HuggingFace rejects invalid version-argument combinations.

    Each case builds the estimator with a different combination of
    ``transformers_version``, ``pytorch_version`` and ``tensorflow_version``
    and asserts on the message of the ``ValueError`` that is raised.
    """
    # Arguments shared by every case; only the version arguments vary.
    common_args = dict(
        py_version="py36",
        entry_point=SCRIPT_PATH,
        role=ROLE,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        enable_sagemaker_metrics=False,
    )

    # NOTE: the assertions inspect ``error.value`` (the raised exception), not
    # ``error`` (pytest's ExceptionInfo wrapper), so they check the actual
    # message rather than the wrapper's string representation.

    # Three-part transformers version combined with a two-part pytorch version.
    with pytest.raises(ValueError) as error:
        HuggingFace(
            transformers_version="4.2.1",
            pytorch_version="1.6",
            **common_args,
        )
    assert "use either full version or shortened version" in str(error.value)

    # No transformers_version (and no image_uri).
    with pytest.raises(ValueError) as error:
        HuggingFace(
            pytorch_version="1.6",
            **common_args,
        )
    assert "transformers_version, and image_uri are both None." in str(error.value)

    # Neither tensorflow_version nor pytorch_version.
    with pytest.raises(ValueError) as error:
        HuggingFace(
            transformers_version="4.2.1",
            **common_args,
        )
    assert "tensorflow_version and pytorch_version are both None." in str(error.value)

    # Both tensorflow_version and pytorch_version given at once.
    with pytest.raises(ValueError) as error:
        HuggingFace(
            transformers_version="4.2",
            pytorch_version="1.6",
            tensorflow_version="2.3",
            **common_args,
        )
    assert "tensorflow_version and pytorch_version are both not None." in str(error.value)
222+
223+
171224
@patch("sagemaker.utils.repack_model", MagicMock())
172225
@patch("sagemaker.utils.create_tar_file", MagicMock())
173226
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)

0 commit comments

Comments
 (0)