Skip to content

Commit 42a2929

Browse files
icywang86rui, Rui Wang Napieralski, ahsan-z-khan
authored
fix: add version length mismatch validation for HuggingFace (#2266)
* fix: add version length mismatch validation for HuggingFace
* modify huggingface_pytorch_version fixture

Co-authored-by: Rui Wang Napieralski <[email protected]>
Co-authored-by: Ahsan Khan <[email protected]>
1 parent 709bb18 commit 42a2929

File tree

4 files changed

+90
-6
lines changed

4 files changed

+90
-6
lines changed

src/sagemaker/huggingface/estimator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def _validate_args(self, image_uri):
189189
"""Placeholder docstring"""
190190
if image_uri is not None:
191191
return
192+
192193
if self.framework_version is None and image_uri is None:
193194
raise ValueError(
194195
"transformers_version, and image_uri are both None. "
@@ -204,6 +205,17 @@ def _validate_args(self, image_uri):
204205
"tensorflow_version and pytorch_version are both None. "
205206
"Specify either tensorflow_version or pytorch_version."
206207
)
208+
base_framework_version_len = (
209+
len(self.tensorflow_version.split("."))
210+
if self.tensorflow_version is not None
211+
else len(self.pytorch_version.split("."))
212+
)
213+
transformers_version_len = len(self.framework_version.split("."))
214+
if transformers_version_len != base_framework_version_len:
215+
raise ValueError(
216+
"Please use either full version or shortened version for both "
217+
"transformers_version, tensorflow_version and pytorch_version."
218+
)
207219

208220
def hyperparameters(self):
209221
"""Return hyperparameters used by your custom PyTorch code during model training."""

tests/conftest.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,12 +190,31 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
190190
return "py3"
191191

192192

193+
def _huggingface_pytorch_version(huggingface_vesion):
    """Return the PyTorch version paired with a HuggingFace training version.

    The HuggingFace image-URI config is consulted: the requested version is
    first resolved through ``version_aliases`` (when that section exists),
    then the matching entry under ``versions`` is scanned for a key of the
    form ``pytorch<version>``.

    Args:
        huggingface_vesion (str): The transformers version, either in full
            form (three dot-separated components) or shortened form (two
            components).

    Returns:
        str: The PyTorch version taken from the first ``pytorch*`` key. When
        the requested transformers version has two components, the result is
        limited to its first two components so both versions use the same
        form.

    Raises:
        ValueError: If the config has no entry for the requested version, or
            that entry has no ``pytorch*`` key.
    """
    prefix = "pytorch"
    training_config = image_uris.config_for_framework("huggingface").get("training")

    # Shortened versions are expected to be listed as aliases of a full one;
    # anything that is not an alias is looked up as given.
    full_version = huggingface_vesion
    if "version_aliases" in training_config:
        full_version = training_config.get("version_aliases").get(
            huggingface_vesion, huggingface_vesion
        )

    version_config = training_config.get("versions").get(full_version)
    if version_config is None:
        raise ValueError(
            "No HuggingFace training config found for version {!r}.".format(huggingface_vesion)
        )

    # Mirror the form of the version the caller asked for.
    use_short_form = len(huggingface_vesion.split(".")) == 2

    for key in version_config:
        if key.startswith(prefix):
            pt_version = key[len(prefix):]
            if use_short_form:
                # Keep major.minor only; a key that is already short is left intact.
                pt_version = ".".join(pt_version.split(".")[:2])
            return pt_version

    raise ValueError(
        "No PyTorch version is configured for HuggingFace version {!r}.".format(huggingface_vesion)
    )
208+
209+
193210
@pytest.fixture(scope="module")
194211
def huggingface_pytorch_version(huggingface_training_version):
195-
if Version(huggingface_training_version) <= Version("4.4.2"):
196-
return "1.6.0"
197-
else:
198-
pytest.skip("Skipping Huggingface version.")
212+
return _huggingface_pytorch_version(huggingface_training_version)
213+
214+
215+
@pytest.fixture(scope="module")
def huggingface_pytorch_latest_version(huggingface_training_latest_version):
    """Provide the PyTorch version matching the latest HuggingFace training version."""
    latest_pytorch_version = _huggingface_pytorch_version(huggingface_training_latest_version)
    return latest_pytorch_version
199218

200219

201220
@pytest.fixture(scope="module")

tests/integ/test_huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_huggingface_training(
3030
sagemaker_session,
3131
gpu_instance_type,
3232
huggingface_training_latest_version,
33-
huggingface_pytorch_version,
33+
huggingface_pytorch_latest_version,
3434
):
3535
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
3636
data_path = os.path.join(DATA_DIR, "huggingface")
@@ -40,7 +40,7 @@ def test_huggingface_training(
4040
entry_point="examples/text-classification/run_glue.py",
4141
role="SageMakerRole",
4242
transformers_version=huggingface_training_latest_version,
43-
pytorch_version=huggingface_pytorch_version,
43+
pytorch_version=huggingface_pytorch_latest_version,
4444
instance_count=1,
4545
instance_type=gpu_instance_type,
4646
hyperparameters={

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,59 @@ def _create_train_job(version, base_framework_version):
168168
}
169169

170170

171+
def test_huggingface_invalid_args():
    """Check that invalid version argument combinations raise ``ValueError``.

    Each case pairs the version-related keyword arguments passed to
    ``HuggingFace`` with a fragment of the error message expected from
    argument validation. The message is read from ``error.value`` (the raised
    exception) rather than from the ``ExceptionInfo`` wrapper, whose string
    form is not guaranteed to be the exception message.
    """
    common_kwargs = dict(
        py_version="py36",
        entry_point=SCRIPT_PATH,
        role=ROLE,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        enable_sagemaker_metrics=False,
    )

    invalid_cases = [
        # Full transformers version mixed with a shortened PyTorch version.
        (
            {"transformers_version": "4.2.1", "pytorch_version": "1.6"},
            "use either full version or shortened version",
        ),
        # Neither transformers_version nor image_uri supplied.
        (
            {"pytorch_version": "1.6"},
            "transformers_version, and image_uri are both None.",
        ),
        # No base framework version supplied.
        (
            {"transformers_version": "4.2.1"},
            "tensorflow_version and pytorch_version are both None.",
        ),
        # Both base framework versions supplied.
        (
            {
                "transformers_version": "4.2",
                "pytorch_version": "1.6",
                "tensorflow_version": "2.3",
            },
            "tensorflow_version and pytorch_version are both not None.",
        ),
    ]

    for version_kwargs, expected_message in invalid_cases:
        with pytest.raises(ValueError) as error:
            HuggingFace(**common_kwargs, **version_kwargs)
        # Include the failing case in the assertion output for easier triage.
        assert expected_message in str(error.value), version_kwargs
222+
223+
171224
@patch("sagemaker.utils.repack_model", MagicMock())
172225
@patch("sagemaker.utils.create_tar_file", MagicMock())
173226
@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)

0 commit comments

Comments
 (0)