Skip to content

Commit 40c9d2c

Browse files
authored
add canary tests for tf serving container (#478)
* add canary tests for tf serving container
1 parent 49912ff commit 40c9d2c

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

tests/integ/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,9 @@
2525
PYTHON_VERSION = 'py' + str(sys.version_info.major)
2626
REGION = boto3.session.Session().region_name
2727

28+
HOSTING_P2_UNAVAILABLE_REGIONS = ['ca-central-1', 'us-west-1', 'eu-west-2']
29+
HOSTING_P3_UNAVAILABLE_REGIONS = ['ap-southeast-1', 'ap-southeast-2', 'ap-south-1', 'ca-central-1',
30+
'us-west-1']
31+
2832
logging.getLogger('boto3').setLevel(logging.INFO)
2933
logging.getLogger('botocore').setLevel(logging.INFO)

tests/integ/test_tfs.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
import sagemaker
1818
import sagemaker.predictor
1919
import sagemaker.utils
20+
import tests.integ
21+
import tests.integ.timeout
2022
from sagemaker.tensorflow.serving import Model, Predictor
21-
from tests.integ.timeout import timeout_and_delete_endpoint_by_name
2223

2324

2425
@pytest.fixture(scope='session', params=['ml.c5.xlarge', 'ml.p3.2xlarge'])
@@ -32,17 +33,21 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
3233
model_data = sagemaker_session.upload_data(
3334
path='tests/data/tensorflow-serving-test-model.tar.gz',
3435
key_prefix='tensorflow-serving/models')
35-
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
36+
with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
3637
model = Model(model_data=model_data, role='SageMakerRole',
3738
framework_version=tf_full_version,
3839
sagemaker_session=sagemaker_session)
3940
predictor = model.deploy(1, instance_type, endpoint_name=endpoint_name)
4041
yield predictor
4142

4243

43-
# @pytest.mark.continuous_testing
44-
# @pytest.mark.regional_testing
45-
def test_predict(tfs_predictor):
44+
@pytest.mark.continuous_testing
45+
@pytest.mark.regional_testing
46+
def test_predict(tfs_predictor, instance_type):
47+
if ('p3' in instance_type) and (
48+
tests.integ.REGION in tests.integ.HOSTING_P3_UNAVAILABLE_REGIONS):
49+
pytest.skip('no ml.p3 instances in this region')
50+
4651
input_data = {'instances': [1.0, 2.0, 5.0]}
4752
expected_result = {'predictions': [3.5, 4.0, 5.5]}
4853

0 commit comments

Comments
 (0)