15
15
import numpy
16
16
import os
17
17
import pytest
18
- from sagemaker . pytorch . defaults import LATEST_PY2_VERSION
18
+
19
19
from sagemaker .pytorch .estimator import PyTorch
20
20
from sagemaker .pytorch .model import PyTorchModel
21
21
from sagemaker .utils import sagemaker_timestamp
22
-
23
22
from tests .integ import (
24
23
test_region ,
25
24
DATA_DIR ,
26
- PYTHON_VERSION ,
27
25
TRAINING_DEFAULT_TIMEOUT_MINUTES ,
28
26
EI_SUPPORTED_REGIONS ,
29
27
)
39
37
40
38
41
39
@pytest .fixture (scope = "module" , name = "pytorch_training_job" )
42
- def fixture_training_job (sagemaker_session , pytorch_full_version , cpu_instance_type ):
40
+ def fixture_training_job (
41
+ sagemaker_session , pytorch_full_version , pytorch_full_py_version , cpu_instance_type
42
+ ):
43
43
with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
44
- pytorch = _get_pytorch_estimator (sagemaker_session , pytorch_full_version , cpu_instance_type )
44
+ pytorch = _get_pytorch_estimator (
45
+ sagemaker_session , pytorch_full_version , pytorch_full_py_version , cpu_instance_type
46
+ )
45
47
46
48
pytorch .fit ({"training" : _upload_training_data (pytorch )})
47
49
return pytorch .latest_training_job .name
48
50
49
51
50
52
@pytest .mark .canary_quick
51
53
@pytest .mark .regional_testing
52
- @pytest .mark .skipif (
53
- PYTHON_VERSION == "py2" ,
54
- reason = "Python 2 is supported by PyTorch {} and lower versions." .format (LATEST_PY2_VERSION ),
55
- )
56
- def test_sync_fit_deploy (pytorch_training_job , sagemaker_session , cpu_instance_type ):
57
- # TODO: add tests against local mode when it's ready to be used
54
+ def test_fit_deploy (pytorch_training_job , sagemaker_session , cpu_instance_type ):
58
55
endpoint_name = "test-pytorch-sync-fit-attach-deploy{}" .format (sagemaker_timestamp ())
59
56
with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
60
57
estimator = PyTorch .attach (pytorch_training_job , sagemaker_session = sagemaker_session )
@@ -70,16 +67,12 @@ def test_sync_fit_deploy(pytorch_training_job, sagemaker_session, cpu_instance_t
70
67
71
68
72
69
@pytest .mark .local_mode
73
- @pytest .mark .skipif (
74
- PYTHON_VERSION == "py2" ,
75
- reason = "Python 2 is supported by PyTorch {} and lower versions." .format (LATEST_PY2_VERSION ),
76
- )
77
- def test_fit_deploy (sagemaker_local_session , pytorch_full_version ):
70
+ def test_local_fit_deploy (sagemaker_local_session , pytorch_full_version , pytorch_full_py_version ):
78
71
pytorch = PyTorch (
79
72
entry_point = MNIST_SCRIPT ,
80
73
role = "SageMakerRole" ,
81
74
framework_version = pytorch_full_version ,
82
- py_version = "py3" ,
75
+ py_version = pytorch_full_py_version ,
83
76
train_instance_count = 1 ,
84
77
train_instance_type = "local" ,
85
78
sagemaker_session = sagemaker_local_session ,
@@ -99,7 +92,11 @@ def test_fit_deploy(sagemaker_local_session, pytorch_full_version):
99
92
100
93
101
94
def test_deploy_model (
102
- pytorch_training_job , sagemaker_session , cpu_instance_type , pytorch_full_version
95
+ pytorch_training_job ,
96
+ sagemaker_session ,
97
+ cpu_instance_type ,
98
+ pytorch_full_version ,
99
+ pytorch_full_py_version ,
103
100
):
104
101
endpoint_name = "test-pytorch-deploy-model-{}" .format (sagemaker_timestamp ())
105
102
@@ -113,7 +110,7 @@ def test_deploy_model(
113
110
"SageMakerRole" ,
114
111
entry_point = MNIST_SCRIPT ,
115
112
framework_version = pytorch_full_version ,
116
- py_version = "py3" ,
113
+ py_version = pytorch_full_py_version ,
117
114
sagemaker_session = sagemaker_session ,
118
115
)
119
116
predictor = model .deploy (1 , cpu_instance_type , endpoint_name = endpoint_name )
@@ -125,7 +122,9 @@ def test_deploy_model(
125
122
assert output .shape == (batch_size , 10 )
126
123
127
124
128
- def test_deploy_packed_model_with_entry_point_name (sagemaker_session , cpu_instance_type ):
125
+ def test_deploy_packed_model_with_entry_point_name (
126
+ sagemaker_session , cpu_instance_type , pytorch_full_version , pytorch_full_py_version
127
+ ):
129
128
endpoint_name = "test-pytorch-deploy-model-{}" .format (sagemaker_timestamp ())
130
129
131
130
with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
@@ -134,8 +133,8 @@ def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instan
134
133
model_data ,
135
134
"SageMakerRole" ,
136
135
entry_point = "mnist.py" ,
137
- framework_version = "1.4.0" ,
138
- py_version = "py3" ,
136
+ framework_version = pytorch_full_version ,
137
+ py_version = pytorch_full_py_version ,
139
138
sagemaker_session = sagemaker_session ,
140
139
)
141
140
predictor = model .deploy (1 , cpu_instance_type , endpoint_name = endpoint_name )
@@ -147,19 +146,20 @@ def test_deploy_packed_model_with_entry_point_name(sagemaker_session, cpu_instan
147
146
assert output .shape == (batch_size , 10 )
148
147
149
148
150
- @pytest .mark .skipif (PYTHON_VERSION == "py2" , reason = "PyTorch EIA does not support Python 2." )
151
149
@pytest .mark .skipif (
152
150
test_region () not in EI_SUPPORTED_REGIONS , reason = "EI isn't supported in that specific region."
153
151
)
154
- def test_deploy_model_with_accelerator (sagemaker_session , cpu_instance_type ):
152
+ def test_deploy_model_with_accelerator (
153
+ sagemaker_session , cpu_instance_type , pytorch_full_ei_version , pytorch_full_py_version
154
+ ):
155
155
endpoint_name = "test-pytorch-deploy-eia-{}" .format (sagemaker_timestamp ())
156
156
model_data = sagemaker_session .upload_data (path = EIA_MODEL )
157
157
pytorch = PyTorchModel (
158
158
model_data ,
159
159
"SageMakerRole" ,
160
160
entry_point = EIA_SCRIPT ,
161
- framework_version = "1.3.1" ,
162
- py_version = "py3" ,
161
+ framework_version = pytorch_full_ei_version ,
162
+ py_version = pytorch_full_py_version ,
163
163
sagemaker_session = sagemaker_session ,
164
164
)
165
165
with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
@@ -185,13 +185,13 @@ def _upload_training_data(pytorch):
185
185
186
186
187
187
def _get_pytorch_estimator (
188
- sagemaker_session , pytorch_full_version , instance_type , entry_point = MNIST_SCRIPT
188
+ sagemaker_session , pytorch_version , py_version , instance_type , entry_point = MNIST_SCRIPT
189
189
):
190
190
return PyTorch (
191
191
entry_point = entry_point ,
192
192
role = "SageMakerRole" ,
193
- framework_version = pytorch_full_version ,
194
- py_version = "py3" ,
193
+ framework_version = pytorch_version ,
194
+ py_version = py_version ,
195
195
train_instance_count = 1 ,
196
196
train_instance_type = instance_type ,
197
197
sagemaker_session = sagemaker_session ,
0 commit comments