19
19
from sagemaker .huggingface import HuggingFace , HuggingFaceProcessor
20
20
from sagemaker .huggingface .model import HuggingFaceModel , HuggingFacePredictor
21
21
from sagemaker .utils import unique_name_from_base
22
- from tests import integ
23
- from tests .integ .utils import gpu_list , retry_with_instance_list
24
22
from tests .integ import DATA_DIR , TRAINING_DEFAULT_TIMEOUT_MINUTES
25
23
from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
26
24
27
25
ROLE = "SageMakerRole"
28
26
29
27
30
28
@pytest .mark .release
31
- @pytest .mark .skipif (
32
- integ .test_region () in integ .TRAINING_NO_P2_REGIONS
33
- and integ .test_region () in integ .TRAINING_NO_P3_REGIONS ,
34
- reason = "no ml.p2 or ml.p3 instances in this region" ,
35
- )
36
- @retry_with_instance_list (gpu_list (integ .test_region ()))
37
29
def test_framework_processing_job_with_deps (
38
30
sagemaker_session ,
39
31
huggingface_training_latest_version ,
40
32
huggingface_training_pytorch_latest_version ,
41
33
huggingface_pytorch_latest_training_py_version ,
42
- ** kwargs ,
34
+ gpu_pytorch_instance_type ,
43
35
):
44
36
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
45
37
code_path = os .path .join (DATA_DIR , "dummy_code_bundle_with_reqs" )
@@ -51,7 +43,7 @@ def test_framework_processing_job_with_deps(
51
43
py_version = huggingface_pytorch_latest_training_py_version ,
52
44
role = ROLE ,
53
45
instance_count = 1 ,
54
- instance_type = kwargs [ "instance_type" ] ,
46
+ instance_type = gpu_pytorch_instance_type ,
55
47
sagemaker_session = sagemaker_session ,
56
48
base_job_name = "test-huggingface" ,
57
49
)
@@ -64,18 +56,12 @@ def test_framework_processing_job_with_deps(
64
56
65
57
66
58
@pytest .mark .release
67
- @pytest .mark .skipif (
68
- integ .test_region () in integ .TRAINING_NO_P2_REGIONS
69
- and integ .test_region () in integ .TRAINING_NO_P3_REGIONS ,
70
- reason = "no ml.p2 or ml.p3 instances in this region" ,
71
- )
72
- @retry_with_instance_list (gpu_list (integ .test_region ()))
73
59
def test_huggingface_training (
74
60
sagemaker_session ,
75
61
huggingface_training_latest_version ,
76
62
huggingface_training_pytorch_latest_version ,
77
63
huggingface_pytorch_latest_training_py_version ,
78
- ** kwargs ,
64
+ gpu_pytorch_instance_type ,
79
65
):
80
66
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
81
67
data_path = os .path .join (DATA_DIR , "huggingface" )
@@ -87,7 +73,7 @@ def test_huggingface_training(
87
73
transformers_version = huggingface_training_latest_version ,
88
74
pytorch_version = huggingface_training_pytorch_latest_version ,
89
75
instance_count = 1 ,
90
- instance_type = kwargs [ "instance_type" ] ,
76
+ instance_type = gpu_pytorch_instance_type ,
91
77
hyperparameters = {
92
78
"model_name_or_path" : "distilbert-base-cased" ,
93
79
"task_name" : "wnli" ,
@@ -111,14 +97,6 @@ def test_huggingface_training(
111
97
112
98
113
99
@pytest .mark .release
114
- @pytest .mark .skipif (
115
- integ .test_region () in integ .TRAINING_NO_P2_REGIONS
116
- and integ .test_region () in integ .TRAINING_NO_P3_REGIONS ,
117
- reason = "no ml.p2 or ml.p3 instances in this region" ,
118
- )
119
- @pytest .mark .skip (
120
- reason = "need to re enable it later t.corp:V609860141" ,
121
- )
122
100
def test_huggingface_training_tf (
123
101
sagemaker_session ,
124
102
gpu_instance_type ,
@@ -161,7 +139,7 @@ def test_huggingface_training_tf(
161
139
)
162
140
def test_huggingface_inference (
163
141
sagemaker_session ,
164
- gpu_instance_type ,
142
+ gpu_pytorch_instance_type ,
165
143
huggingface_inference_latest_version ,
166
144
huggingface_inference_pytorch_latest_version ,
167
145
huggingface_pytorch_latest_inference_py_version ,
@@ -182,7 +160,9 @@ def test_huggingface_inference(
182
160
)
183
161
with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
184
162
model .deploy (
185
- instance_type = gpu_instance_type , initial_instance_count = 1 , endpoint_name = endpoint_name
163
+ instance_type = gpu_pytorch_instance_type ,
164
+ initial_instance_count = 1 ,
165
+ endpoint_name = endpoint_name ,
186
166
)
187
167
188
168
predictor = HuggingFacePredictor (endpoint_name = endpoint_name )
0 commit comments