@@ -138,6 +138,42 @@ def tfs_predictor_with_accelerator(
138
138
yield predictor
139
139
140
140
141
@pytest.fixture(scope="module")
def tfs_trt_predictor_with_accelerator(
    sagemaker_session, tensorflow_eia_latest_version, cpu_instance_type
):
    """Yield a predictor for a compiled TensorFlow Serving model with an EI accelerator.

    Uploads the test model archive to S3, compiles it for the ``ml_eia2``
    target, deploys the compiled model to a single ``cpu_instance_type``
    instance with an ``ml.eia2.large`` accelerator attached, and yields the
    resulting predictor.  The surrounding ``timeout_and_delete_endpoint_by_name``
    context manager is entered before the model is built and exited after the
    consuming tests finish.

    Args:
        sagemaker_session: Session used for the S3 upload, compilation and deployment.
        tensorflow_eia_latest_version: Framework version passed to ``TensorFlowModel``.
        cpu_instance_type: Instance type the endpoint is deployed on.

    Yields:
        The predictor returned by ``compiled_model.deploy``.
    """
    endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
    model_data = sagemaker_session.upload_data(
        path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
        key_prefix="tensorflow-serving/compiledmodels",
    )
    bucket = sagemaker_session.default_bucket()
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        model = TensorFlowModel(
            model_data=model_data,
            role="SageMakerRole",
            framework_version=tensorflow_eia_latest_version,
            sagemaker_session=sagemaker_session,
            name=endpoint_name,
        )
        # NOTE(review): assumes the model's input tensor is named "input" and
        # takes a single 224x224 RGB image -- confirm against the test model.
        data_shape = {"input": [1, 224, 224, 3]}
        # Job names must be unique per run and may not contain underscores, so
        # derive one the same way the endpoint name is derived instead of using
        # a fixed literal that collides across runs.
        tfs_eia_compilation_job_name = sagemaker.utils.unique_name_from_base(
            "tfs-eia-compilation-job"
        )
        compiled_model_path = "s3://{}/{}/output".format(bucket, tfs_eia_compilation_job_name)
        # NOTE(review): the compile step pins framework_version to "2.3" while
        # the model itself uses tensorflow_eia_latest_version -- confirm the
        # two are meant to differ.
        compiled_model = model.compile(
            target_instance_family="ml_eia2",
            input_shape=data_shape,
            output_path=compiled_model_path,
            role="SageMakerRole",
            job_name=tfs_eia_compilation_job_name,
            framework="tensorflow",
            framework_version="2.3",
        )
        predictor = compiled_model.deploy(
            1, cpu_instance_type, endpoint_name=endpoint_name, accelerator_type="ml.eia2.large"
        )
        yield predictor
175
+
176
+
141
177
@pytest .mark .release
142
178
def test_predict (tfs_predictor ):
143
179
input_data = {"instances" : [1.0 , 2.0 , 5.0 ]}
@@ -160,6 +196,23 @@ def test_predict_with_accelerator(tfs_predictor_with_accelerator):
160
196
assert expected_result == result
161
197
162
198
199
@pytest.mark.skipif(
    tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS,
    reason="EI is not supported in region {}".format(tests.integ.test_region()),
)
@pytest.mark.release
def test_trt_predict_with_accelerator(tfs_trt_predictor_with_accelerator):
    """Send one image-shaped request to the compiled-model endpoint.

    Requests the ``tfs_trt_predictor_with_accelerator`` fixture (the name the
    body actually uses), builds a ``(1, 224, 224, 3)`` array from the
    ``cuteCat.jpg`` test image and checks that ``predict`` returns a response.
    """
    # Imported locally so the rest of the module does not depend on these
    # packages being importable.
    import numpy as np
    import matplotlib.image as mpimg

    path = os.path.join(tests.integ.DATA_DIR, "cuteCat.jpg")
    img = mpimg.imread(path)
    # np.resize repeats or truncates the flattened data to reach the requested
    # shape (no interpolation); only the resulting shape matters here.
    img = np.resize(img, (224, 224, 3))
    # Add the leading batch dimension -> (1, 224, 224, 3).
    img = np.expand_dims(img, axis=0)
    input_data = {"inputs": img}
    result = tfs_trt_predictor_with_accelerator.predict(input_data)
    # format() instead of "+" so a non-string response does not raise TypeError.
    print("trt predictor result is: {}".format(result))
    assert result is not None
214
+
215
+
163
216
@pytest .mark .local_mode
164
217
def test_predict_with_entry_point (tfs_predictor_with_model_and_entry_point_same_tar ):
165
218
input_data = {"instances" : [1.0 , 2.0 , 5.0 ]}
0 commit comments