Skip to content

Commit 8d282c1

Browse files
tejaschumbalkarNivekNeyandre-marcos-perezmufaddal-rohawala
authored
fix: Handle instance support for Hugging Face tests (#3729)
Co-authored-by: Kevin <[email protected]> Co-authored-by: André Perez <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]>
1 parent e892400 commit 8d282c1

File tree

2 files changed

+17
-32
lines changed

2 files changed

+17
-32
lines changed

tests/conftest.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,11 +567,16 @@ def gpu_instance_type(sagemaker_session, request):
567567

568568
@pytest.fixture()
569569
def gpu_pytorch_instance_type(sagemaker_session, request):
570-
if "pytorch_inference_version" in request.fixturenames:
571-
fw_version = request.getfixturevalue("pytorch_inference_version")
572-
else:
570+
fw_version = None
571+
for pytorch_version_fixture in [
572+
"pytorch_inference_version",
573+
"huggingface_training_pytorch_latest_version",
574+
"huggingface_inference_pytorch_latest_version",
575+
]:
576+
if pytorch_version_fixture in request.fixturenames:
577+
fw_version = request.getfixturevalue(pytorch_version_fixture)
578+
if fw_version is None:
573579
fw_version = request.param
574-
575580
region = sagemaker_session.boto_session.region_name
576581
if region in NO_P3_REGIONS:
577582
if Version(fw_version) >= Version("1.13"):

tests/integ/test_huggingface.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,19 @@
1919
from sagemaker.huggingface import HuggingFace, HuggingFaceProcessor
2020
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor
2121
from sagemaker.utils import unique_name_from_base
22-
from tests import integ
23-
from tests.integ.utils import gpu_list, retry_with_instance_list
2422
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2523
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2624

2725
ROLE = "SageMakerRole"
2826

2927

3028
@pytest.mark.release
31-
@pytest.mark.skipif(
32-
integ.test_region() in integ.TRAINING_NO_P2_REGIONS
33-
and integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
34-
reason="no ml.p2 or ml.p3 instances in this region",
35-
)
36-
@retry_with_instance_list(gpu_list(integ.test_region()))
3729
def test_framework_processing_job_with_deps(
3830
sagemaker_session,
3931
huggingface_training_latest_version,
4032
huggingface_training_pytorch_latest_version,
4133
huggingface_pytorch_latest_training_py_version,
42-
**kwargs,
34+
gpu_pytorch_instance_type,
4335
):
4436
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
4537
code_path = os.path.join(DATA_DIR, "dummy_code_bundle_with_reqs")
@@ -51,7 +43,7 @@ def test_framework_processing_job_with_deps(
5143
py_version=huggingface_pytorch_latest_training_py_version,
5244
role=ROLE,
5345
instance_count=1,
54-
instance_type=kwargs["instance_type"],
46+
instance_type=gpu_pytorch_instance_type,
5547
sagemaker_session=sagemaker_session,
5648
base_job_name="test-huggingface",
5749
)
@@ -64,18 +56,12 @@ def test_framework_processing_job_with_deps(
6456

6557

6658
@pytest.mark.release
67-
@pytest.mark.skipif(
68-
integ.test_region() in integ.TRAINING_NO_P2_REGIONS
69-
and integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
70-
reason="no ml.p2 or ml.p3 instances in this region",
71-
)
72-
@retry_with_instance_list(gpu_list(integ.test_region()))
7359
def test_huggingface_training(
7460
sagemaker_session,
7561
huggingface_training_latest_version,
7662
huggingface_training_pytorch_latest_version,
7763
huggingface_pytorch_latest_training_py_version,
78-
**kwargs,
64+
gpu_pytorch_instance_type,
7965
):
8066
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
8167
data_path = os.path.join(DATA_DIR, "huggingface")
@@ -87,7 +73,7 @@ def test_huggingface_training(
8773
transformers_version=huggingface_training_latest_version,
8874
pytorch_version=huggingface_training_pytorch_latest_version,
8975
instance_count=1,
90-
instance_type=kwargs["instance_type"],
76+
instance_type=gpu_pytorch_instance_type,
9177
hyperparameters={
9278
"model_name_or_path": "distilbert-base-cased",
9379
"task_name": "wnli",
@@ -111,14 +97,6 @@ def test_huggingface_training(
11197

11298

11399
@pytest.mark.release
114-
@pytest.mark.skipif(
115-
integ.test_region() in integ.TRAINING_NO_P2_REGIONS
116-
and integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
117-
reason="no ml.p2 or ml.p3 instances in this region",
118-
)
119-
@pytest.mark.skip(
120-
reason="need to re enable it later t.corp:V609860141",
121-
)
122100
def test_huggingface_training_tf(
123101
sagemaker_session,
124102
gpu_instance_type,
@@ -161,7 +139,7 @@ def test_huggingface_training_tf(
161139
)
162140
def test_huggingface_inference(
163141
sagemaker_session,
164-
gpu_instance_type,
142+
gpu_pytorch_instance_type,
165143
huggingface_inference_latest_version,
166144
huggingface_inference_pytorch_latest_version,
167145
huggingface_pytorch_latest_inference_py_version,
@@ -182,7 +160,9 @@ def test_huggingface_inference(
182160
)
183161
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
184162
model.deploy(
185-
instance_type=gpu_instance_type, initial_instance_count=1, endpoint_name=endpoint_name
163+
instance_type=gpu_pytorch_instance_type,
164+
initial_instance_count=1,
165+
endpoint_name=endpoint_name,
186166
)
187167

188168
predictor = HuggingFacePredictor(endpoint_name=endpoint_name)

0 commit comments

Comments
 (0)