from sagemaker.utils import unique_name_from_base

import tests.integ
+from tests.integ import timeout

ROLE = 'SageMakerRole'

-RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'data', 'tensorflow_mnist')
-SCRIPT = os.path.join(RESOURCE_PATH, 'mnist.py')
+RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'data')
+MNIST_RESOURCE_PATH = os.path.join(RESOURCE_PATH, 'tensorflow_mnist')
+TFS_RESOURCE_PATH = os.path.join(RESOURCE_PATH, 'tfs', 'tfs-test-entrypoint-with-handler')
+
+SCRIPT = os.path.join(MNIST_RESOURCE_PATH, 'mnist.py')

PARAMETER_SERVER_DISTRIBUTION = {'parameter_server': {'enabled': True}}
MPI_DISTRIBUTION = {'mpi': {'enabled': True}}
TAGS = [{'Key': 'some-key', 'Value': 'some-value'}]
@@ -57,7 +61,7 @@ def test_mnist(sagemaker_session, instance_type):
                           metric_definitions=[
                               {'Name': 'train:global_steps', 'Regex': r'global_step\/sec:\s(.*)'}])
    inputs = estimator.sagemaker_session.upload_data(
-       path=os.path.join(RESOURCE_PATH, 'data'),
+       path=os.path.join(MNIST_RESOURCE_PATH, 'data'),
        key_prefix='scriptmode/mnist')

    with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
@@ -88,7 +92,7 @@ def test_server_side_encryption(sagemaker_session):
                           output_kms_key=kms_key)

    inputs = estimator.sagemaker_session.upload_data(
-       path=os.path.join(RESOURCE_PATH, 'data'),
+       path=os.path.join(MNIST_RESOURCE_PATH, 'data'),
        key_prefix='scriptmode/mnist')

    with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
@@ -110,7 +114,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
                           framework_version=TensorFlow.LATEST_VERSION,
                           distributions=PARAMETER_SERVER_DISTRIBUTION)
    inputs = estimator.sagemaker_session.upload_data(
-       path=os.path.join(RESOURCE_PATH, 'data'),
+       path=os.path.join(MNIST_RESOURCE_PATH, 'data'),
        key_prefix='scriptmode/distributed_mnist')

    with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
@@ -129,7 +133,7 @@ def test_mnist_async(sagemaker_session):
                           framework_version=TensorFlow.LATEST_VERSION,
                           tags=TAGS)
    inputs = estimator.sagemaker_session.upload_data(
-       path=os.path.join(RESOURCE_PATH, 'data'),
+       path=os.path.join(MNIST_RESOURCE_PATH, 'data'),
        key_prefix='scriptmode/mnist')
    estimator.fit(inputs=inputs, wait=False, job_name=unique_name_from_base('test-tf-sm-async'))
    training_job_name = estimator.latest_training_job.name
@@ -150,6 +154,35 @@ def test_mnist_async(sagemaker_session):
                       estimator.latest_training_job.name, TAGS)


+@pytest.mark.skipif(tests.integ.PYTHON_VERSION != 'py3',
+                    reason="Script Mode tests are only configured to run with Python 3")
+def test_deploy_with_input_handlers(sagemaker_session, instance_type):
+    estimator = TensorFlow(entry_point='inference.py',
+                           source_dir=TFS_RESOURCE_PATH,
+                           role=ROLE,
+                           train_instance_count=1,
+                           train_instance_type=instance_type,
+                           sagemaker_session=sagemaker_session,
+                           py_version='py3',
+                           framework_version=TensorFlow.LATEST_VERSION,
+                           tags=TAGS)
+
+    estimator.fit(job_name=unique_name_from_base('test-tf-tfs-deploy'))
+
+    endpoint_name = estimator.latest_training_job.name
+
+    with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
+
+        predictor = estimator.deploy(initial_instance_count=1, instance_type=instance_type,
+                                     endpoint_name=endpoint_name)
+
+        input_data = {'instances': [1.0, 2.0, 5.0]}
+        expected_result = {'predictions': [4.0, 4.5, 6.0]}
+
+        result = predictor.predict(input_data)
+        assert expected_result == result
+
+
def _assert_s3_files_exist(s3_url, files):
    parsed_url = urlparse(s3_url)
    s3 = boto3.client('s3')
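
Note on the new test: test_deploy_with_input_handlers points entry_point at an inference.py and expects the deployed endpoint to apply request/response handling before and after TensorFlow Serving. The fixture under tests/data/tfs/tfs-test-entrypoint-with-handler is not part of this diff, so the following is only a minimal sketch of the kind of input_handler/output_handler script the SageMaker TensorFlow Serving container accepts; the handler bodies are assumptions, not the actual test fixture.

# inference.py (illustrative sketch, not the fixture shipped with this PR)
def input_handler(data, context):
    # Pre-process the incoming request before it is forwarded to TF Serving.
    # `data` is a stream-like object; `context` carries request metadata.
    if context.request_content_type == 'application/json':
        # Pass JSON payloads such as {'instances': [1.0, 2.0, 5.0]} through unchanged.
        return data.read().decode('utf-8')
    raise ValueError('Unsupported content type: {}'.format(context.request_content_type))


def output_handler(response, context):
    # Post-process the TF Serving response before it is returned to the client.
    if response.status_code != 200:
        raise ValueError(response.content.decode('utf-8'))
    # Return the prediction body, e.g. {'predictions': [4.0, 4.5, 6.0]}, plus the accept type.
    return response.content, context.accept_header

With handlers like these in place, predictor.predict({'instances': [1.0, 2.0, 5.0]}) in the test above would receive whatever the trained model returns for those instances, which the test asserts equals {'predictions': [4.0, 4.5, 6.0]}.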