17
17
import sagemaker
18
18
import sagemaker .predictor
19
19
import sagemaker .utils
20
+ import tests .integ
21
+ import tests .integ .timeout
20
22
from sagemaker .tensorflow .serving import Model , Predictor
21
- from tests .integ .timeout import timeout_and_delete_endpoint_by_name
22
23
23
24
24
25
@pytest .fixture (scope = 'session' , params = ['ml.c5.xlarge' , 'ml.p3.2xlarge' ])
@@ -32,17 +33,21 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
32
33
model_data = sagemaker_session .upload_data (
33
34
path = 'tests/data/tensorflow-serving-test-model.tar.gz' ,
34
35
key_prefix = 'tensorflow-serving/models' )
35
- with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
36
+ with tests . integ . timeout . timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
36
37
model = Model (model_data = model_data , role = 'SageMakerRole' ,
37
38
framework_version = tf_full_version ,
38
39
sagemaker_session = sagemaker_session )
39
40
predictor = model .deploy (1 , instance_type , endpoint_name = endpoint_name )
40
41
yield predictor
41
42
42
43
43
- # @pytest.mark.continuous_testing
44
- # @pytest.mark.regional_testing
45
- def test_predict (tfs_predictor ):
44
+ @pytest .mark .continuous_testing
45
+ @pytest .mark .regional_testing
46
+ def test_predict (tfs_predictor , instance_type ):
47
+ if ('p3' in instance_type ) and (
48
+ tests .integ .REGION in tests .integ .HOSTING_P3_UNAVAILABLE_REGIONS ):
49
+ pytest .skip ('no ml.p3 instances in this region' )
50
+
46
51
input_data = {'instances' : [1.0 , 2.0 , 5.0 ]}
47
52
expected_result = {'predictions' : [3.5 , 4.0 , 5.5 ]}
48
53
0 commit comments