Skip to content

Commit 7ada426

Browse files
fix pytorch version fixture
1 parent 81976fe commit 7ada426

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/conftest.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,11 +567,15 @@ def gpu_instance_type(sagemaker_session, request):
567567

568568
@pytest.fixture()
569569
def gpu_pytorch_instance_type(sagemaker_session, request):
570-
if "pytorch_inference_version" in request.fixturenames:
571-
fw_version = request.getfixturevalue("pytorch_inference_version")
572-
else:
570+
for pytorch_version_fixture in [
571+
"pytorch_inference_version",
572+
"huggingface_training_pytorch_latest_version",
573+
"huggingface_inference_pytorch_latest_version",
574+
]:
575+
if pytorch_version_fixture in request.fixturenames:
576+
fw_version = request.getfixturevalue(pytorch_version_fixture)
577+
if fw_version is None:
573578
fw_version = request.param
574-
575579
region = sagemaker_session.boto_session.region_name
576580
if region in NO_P3_REGIONS:
577581
if Version(fw_version) >= Version("1.13"):

0 commit comments

Comments
 (0)