Skip to content

Commit 58cd44b

Browse files
handle instance support for hf images
1 parent bebce21 commit 58cd44b

File tree

1 file changed: 8 additions and 23 deletions.

tests/integ/test_huggingface.py

Lines changed: 8 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -28,18 +28,13 @@
2828

2929

3030
@pytest.mark.release
31-
@pytest.mark.skipif(
32-
integ.test_region() in integ.TRAINING_NO_P2_REGIONS
33-
and integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
34-
reason="no ml.p2 or ml.p3 instances in this region",
35-
)
3631
@retry_with_instance_list(gpu_list(integ.test_region()))
3732
def test_framework_processing_job_with_deps(
3833
sagemaker_session,
3934
huggingface_training_latest_version,
4035
huggingface_training_pytorch_latest_version,
4136
huggingface_pytorch_latest_training_py_version,
42-
**kwargs,
37+
gpu_pytorch_instance_type,
4338
):
4439
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
4540
code_path = os.path.join(DATA_DIR, "dummy_code_bundle_with_reqs")
@@ -51,7 +46,7 @@ def test_framework_processing_job_with_deps(
5146
py_version=huggingface_pytorch_latest_training_py_version,
5247
role=ROLE,
5348
instance_count=1,
54-
instance_type=kwargs["instance_type"],
49+
instance_type=gpu_pytorch_instance_type,
5550
sagemaker_session=sagemaker_session,
5651
base_job_name="test-huggingface",
5752
)
@@ -64,18 +59,13 @@ def test_framework_processing_job_with_deps(
6459

6560

6661
@pytest.mark.release
67-
@pytest.mark.skipif(
68-
integ.test_region() in integ.TRAINING_NO_P2_REGIONS
69-
and integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
70-
reason="no ml.p2 or ml.p3 instances in this region",
71-
)
7262
@retry_with_instance_list(gpu_list(integ.test_region()))
7363
def test_huggingface_training(
7464
sagemaker_session,
7565
huggingface_training_latest_version,
7666
huggingface_training_pytorch_latest_version,
7767
huggingface_pytorch_latest_training_py_version,
78-
**kwargs,
68+
gpu_pytorch_instance_type,
7969
):
8070
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
8171
data_path = os.path.join(DATA_DIR, "huggingface")
@@ -87,7 +77,7 @@ def test_huggingface_training(
8777
transformers_version=huggingface_training_latest_version,
8878
pytorch_version=huggingface_training_pytorch_latest_version,
8979
instance_count=1,
90-
instance_type=kwargs["instance_type"],
80+
instance_type=gpu_pytorch_instance_type,
9181
hyperparameters={
9282
"model_name_or_path": "distilbert-base-cased",
9383
"task_name": "wnli",
@@ -111,17 +101,12 @@ def test_huggingface_training(
111101

112102

113103
@pytest.mark.release
114-
@pytest.mark.skipif(
115-
integ.test_region() in integ.TRAINING_NO_P2_REGIONS
116-
and integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
117-
reason="no ml.p2 or ml.p3 instances in this region",
118-
)
119104
@pytest.mark.skip(
120105
reason="need to re enable it later t.corp:V609860141",
121106
)
122107
def test_huggingface_training_tf(
123108
sagemaker_session,
124-
gpu_instance_type,
109+
gpu_pytorch_instance_type,
125110
huggingface_training_latest_version,
126111
huggingface_training_tensorflow_latest_version,
127112
huggingface_tensorflow_latest_training_py_version,
@@ -136,7 +121,7 @@ def test_huggingface_training_tf(
136121
transformers_version=huggingface_training_latest_version,
137122
tensorflow_version=huggingface_training_tensorflow_latest_version,
138123
instance_count=1,
139-
instance_type=gpu_instance_type,
124+
instance_type=gpu_pytorch_instance_type,
140125
hyperparameters={
141126
"model_name_or_path": "distilbert-base-cased",
142127
"per_device_train_batch_size": 128,
@@ -161,7 +146,7 @@ def test_huggingface_training_tf(
161146
)
162147
def test_huggingface_inference(
163148
sagemaker_session,
164-
gpu_instance_type,
149+
gpu_pytorch_instance_type,
165150
huggingface_inference_latest_version,
166151
huggingface_inference_pytorch_latest_version,
167152
huggingface_pytorch_latest_inference_py_version,
@@ -182,7 +167,7 @@ def test_huggingface_inference(
182167
)
183168
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
184169
model.deploy(
185-
instance_type=gpu_instance_type, initial_instance_count=1, endpoint_name=endpoint_name
170+
instance_type=gpu_pytorch_instance_type, initial_instance_count=1, endpoint_name=endpoint_name
186171
)
187172

188173
predictor = HuggingFacePredictor(endpoint_name=endpoint_name)

0 commit comments

Comments (0)